__init__.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. from __future__ import annotations
  2. import os
  3. from .errors import *
  4. from .models import Model, ModelUtils, _all_models
  5. from .Provider import BaseProvider, AsyncGeneratorProvider, RetryProvider, ProviderUtils
  6. from .typing import Messages, CreateResult, AsyncResult, Union, List
  7. from . import debug
  8. def get_model_and_provider(model : Union[Model, str],
  9. provider : Union[type[BaseProvider], str, None],
  10. stream : bool,
  11. ignored : List[str] = None,
  12. ignore_working: bool = False,
  13. ignore_stream: bool = False) -> tuple[Model, type[BaseProvider]]:
  14. if debug.version_check:
  15. debug.version_check = False
  16. debug.check_pypi_version()
  17. if isinstance(provider, str):
  18. if provider in ProviderUtils.convert:
  19. provider = ProviderUtils.convert[provider]
  20. else:
  21. raise ProviderNotFoundError(f'Provider not found: {provider}')
  22. if isinstance(model, str):
  23. if model in ModelUtils.convert:
  24. model = ModelUtils.convert[model]
  25. else:
  26. raise ModelNotFoundError(f'The model: {model} does not exist')
  27. if not provider:
  28. provider = model.best_provider
  29. if isinstance(provider, RetryProvider) and ignored:
  30. provider.providers = [p for p in provider.providers if p.__name__ not in ignored]
  31. if not provider:
  32. raise ProviderNotFoundError(f'No provider found for model: {model}')
  33. if not provider.working and not ignore_working:
  34. raise ProviderNotWorkingError(f'{provider.__name__} is not working')
  35. if not ignore_stream and not provider.supports_stream and stream:
  36. raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument')
  37. if debug.logging:
  38. print(f'Using {provider.__name__} provider')
  39. return model, provider
  40. class ChatCompletion:
  41. @staticmethod
  42. def create(model : Union[Model, str],
  43. messages : Messages,
  44. provider : Union[type[BaseProvider], str, None] = None,
  45. stream : bool = False,
  46. auth : Union[str, None] = None,
  47. ignored : List[str] = None,
  48. ignore_working: bool = False,
  49. ignore_stream_and_auth: bool = False,
  50. **kwargs) -> Union[CreateResult, str]:
  51. model, provider = get_model_and_provider(model, provider, stream, ignored, ignore_working, ignore_stream_and_auth)
  52. if not ignore_stream_and_auth and provider.needs_auth and not auth:
  53. raise AuthenticationRequiredError(f'{provider.__name__} requires authentication (use auth=\'cookie or token or jwt ...\' param)')
  54. if auth:
  55. kwargs['auth'] = auth
  56. if "proxy" not in kwargs:
  57. proxy = os.environ.get("G4F_PROXY")
  58. if proxy:
  59. kwargs['proxy'] = proxy
  60. result = provider.create_completion(model.name, messages, stream, **kwargs)
  61. return result if stream else ''.join(result)
  62. @staticmethod
  63. async def create_async(model : Union[Model, str],
  64. messages : Messages,
  65. provider : Union[type[BaseProvider], str, None] = None,
  66. stream : bool = False,
  67. ignored : List[str] = None,
  68. **kwargs) -> Union[AsyncResult, str]:
  69. model, provider = get_model_and_provider(model, provider, False, ignored)
  70. if stream:
  71. if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
  72. return await provider.create_async_generator(model.name, messages, **kwargs)
  73. raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')
  74. return await provider.create_async(model.name, messages, **kwargs)
  75. class Completion:
  76. @staticmethod
  77. def create(model : Union[Model, str],
  78. prompt : str,
  79. provider : Union[type[BaseProvider], None] = None,
  80. stream : bool = False,
  81. ignored : List[str] = None, **kwargs) -> Union[CreateResult, str]:
  82. allowed_models = [
  83. 'code-davinci-002',
  84. 'text-ada-001',
  85. 'text-babbage-001',
  86. 'text-curie-001',
  87. 'text-davinci-002',
  88. 'text-davinci-003'
  89. ]
  90. if model not in allowed_models:
  91. raise ModelNotAllowed(f'Can\'t use {model} with Completion.create()')
  92. model, provider = get_model_and_provider(model, provider, stream, ignored)
  93. result = provider.create_completion(model.name, [{"role": "user", "content": prompt}], stream, **kwargs)
  94. return result if stream else ''.join(result)