__init__.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from __future__ import annotations
  2. import os
  3. from . import debug, version
  4. from .models import Model
  5. from .client import Client, AsyncClient
  6. from .typing import Messages, CreateResult, AsyncResult, Union
  7. from .errors import StreamNotSupportedError, ModelNotAllowedError
  8. from .cookies import get_cookies, set_cookies
  9. from .providers.types import ProviderType
  10. from .providers.base_provider import AsyncGeneratorProvider
  11. from .client.service import get_model_and_provider, get_last_provider
  12. class ChatCompletion:
  13. @staticmethod
  14. def create(model : Union[Model, str],
  15. messages : Messages,
  16. provider : Union[ProviderType, str, None] = None,
  17. stream : bool = False,
  18. auth : Union[str, None] = None,
  19. ignored : list[str] = None,
  20. ignore_working: bool = False,
  21. ignore_stream: bool = False,
  22. patch_provider: callable = None,
  23. **kwargs) -> Union[CreateResult, str]:
  24. model, provider = get_model_and_provider(
  25. model, provider, stream,
  26. ignored, ignore_working,
  27. ignore_stream or kwargs.get("ignore_stream_and_auth")
  28. )
  29. if auth is not None:
  30. kwargs['auth'] = auth
  31. if "proxy" not in kwargs:
  32. proxy = os.environ.get("G4F_PROXY")
  33. if proxy:
  34. kwargs['proxy'] = proxy
  35. if patch_provider:
  36. provider = patch_provider(provider)
  37. result = provider.create_completion(model, messages, stream=stream, **kwargs)
  38. return result if stream else ''.join([str(chunk) for chunk in result])
  39. @staticmethod
  40. def create_async(model : Union[Model, str],
  41. messages : Messages,
  42. provider : Union[ProviderType, str, None] = None,
  43. stream : bool = False,
  44. ignored : list[str] = None,
  45. ignore_working: bool = False,
  46. patch_provider: callable = None,
  47. **kwargs) -> Union[AsyncResult, str]:
  48. model, provider = get_model_and_provider(model, provider, False, ignored, ignore_working)
  49. if stream:
  50. if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
  51. return provider.create_async_generator(model, messages, **kwargs)
  52. raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')
  53. if patch_provider:
  54. provider = patch_provider(provider)
  55. return provider.create_async(model, messages, **kwargs)
  56. class Completion:
  57. @staticmethod
  58. def create(model : Union[Model, str],
  59. prompt : str,
  60. provider : Union[ProviderType, None] = None,
  61. stream : bool = False,
  62. ignored : list[str] = None, **kwargs) -> Union[CreateResult, str]:
  63. allowed_models = [
  64. 'code-davinci-002',
  65. 'text-ada-001',
  66. 'text-babbage-001',
  67. 'text-curie-001',
  68. 'text-davinci-002',
  69. 'text-davinci-003'
  70. ]
  71. if model not in allowed_models:
  72. raise ModelNotAllowedError(f'Can\'t use {model} with Completion.create()')
  73. model, provider = get_model_and_provider(model, provider, stream, ignored)
  74. result = provider.create_completion(model, [{"role": "user", "content": prompt}], stream=stream, **kwargs)
  75. return result if stream else ''.join(result)