__init__.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. from __future__ import annotations
  2. import os
  3. import logging
  4. from . import debug, version
  5. from .models import Model
  6. from .client import Client, AsyncClient
  7. from .typing import Messages, CreateResult, AsyncResult, Union
  8. from .errors import StreamNotSupportedError, ModelNotAllowedError
  9. from .cookies import get_cookies, set_cookies
  10. from .providers.types import ProviderType
  11. from .providers.base_provider import AsyncGeneratorProvider
  12. from .client.service import get_model_and_provider, get_last_provider
  13. #Configure "g4f" logger
  14. logger = logging.getLogger(__name__)
  15. log_handler = logging.StreamHandler()
  16. log_handler.setFormatter(logging.Formatter(logging.BASIC_FORMAT))
  17. logger.addHandler(log_handler)
  18. logger.setLevel(logging.ERROR)
  19. class ChatCompletion:
  20. @staticmethod
  21. def create(model : Union[Model, str],
  22. messages : Messages,
  23. provider : Union[ProviderType, str, None] = None,
  24. stream : bool = False,
  25. auth : Union[str, None] = None,
  26. ignored : list[str] = None,
  27. ignore_working: bool = False,
  28. ignore_stream: bool = False,
  29. patch_provider: callable = None,
  30. **kwargs) -> Union[CreateResult, str]:
  31. model, provider = get_model_and_provider(
  32. model, provider, stream,
  33. ignored, ignore_working,
  34. ignore_stream or kwargs.get("ignore_stream_and_auth")
  35. )
  36. if auth is not None:
  37. kwargs['auth'] = auth
  38. if "proxy" not in kwargs:
  39. proxy = os.environ.get("G4F_PROXY")
  40. if proxy:
  41. kwargs['proxy'] = proxy
  42. if patch_provider:
  43. provider = patch_provider(provider)
  44. result = provider.create_completion(model, messages, stream=stream, **kwargs)
  45. return result if stream else ''.join([str(chunk) for chunk in result])
  46. @staticmethod
  47. def create_async(model : Union[Model, str],
  48. messages : Messages,
  49. provider : Union[ProviderType, str, None] = None,
  50. stream : bool = False,
  51. ignored : list[str] = None,
  52. ignore_working: bool = False,
  53. patch_provider: callable = None,
  54. **kwargs) -> Union[AsyncResult, str]:
  55. model, provider = get_model_and_provider(model, provider, False, ignored, ignore_working)
  56. if stream:
  57. if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
  58. return provider.create_async_generator(model, messages, **kwargs)
  59. raise StreamNotSupportedError(f'{provider.__name__} does not support "stream" argument in "create_async"')
  60. if patch_provider:
  61. provider = patch_provider(provider)
  62. return provider.create_async(model, messages, **kwargs)
  63. class Completion:
  64. @staticmethod
  65. def create(model : Union[Model, str],
  66. prompt : str,
  67. provider : Union[ProviderType, None] = None,
  68. stream : bool = False,
  69. ignored : list[str] = None, **kwargs) -> Union[CreateResult, str]:
  70. allowed_models = [
  71. 'code-davinci-002',
  72. 'text-ada-001',
  73. 'text-babbage-001',
  74. 'text-curie-001',
  75. 'text-davinci-002',
  76. 'text-davinci-003'
  77. ]
  78. if model not in allowed_models:
  79. raise ModelNotAllowedError(f'Can\'t use {model} with Completion.create()')
  80. model, provider = get_model_and_provider(model, provider, stream, ignored)
  81. result = provider.create_completion(model, [{"role": "user", "content": prompt}], stream=stream, **kwargs)
  82. return result if stream else ''.join(result)