base_provider.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. from __future__ import annotations
  2. import sys
  3. import asyncio
  4. from asyncio import AbstractEventLoop
  5. from concurrent.futures import ThreadPoolExecutor
  6. from abc import abstractmethod
  7. from inspect import signature, Parameter
  8. from ..typing import CreateResult, AsyncResult, Messages
  9. from .types import BaseProvider
  10. from .asyncio import get_running_loop, to_sync_generator
  11. from .response import FinishReason, BaseConversation, SynthesizeData
  12. from ..errors import ModelNotSupportedError
  13. from .. import debug
  14. # Set Windows event loop policy for better compatibility with asyncio and curl_cffi
  15. if sys.platform == 'win32':
  16. try:
  17. from curl_cffi import aio
  18. if not hasattr(aio, "_get_selector"):
  19. if isinstance(asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy):
  20. asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  21. except ImportError:
  22. pass
  23. class AbstractProvider(BaseProvider):
  24. """
  25. Abstract class for providing asynchronous functionality to derived classes.
  26. """
  27. @classmethod
  28. async def create_async(
  29. cls,
  30. model: str,
  31. messages: Messages,
  32. *,
  33. loop: AbstractEventLoop = None,
  34. executor: ThreadPoolExecutor = None,
  35. **kwargs
  36. ) -> str:
  37. """
  38. Asynchronously creates a result based on the given model and messages.
  39. Args:
  40. cls (type): The class on which this method is called.
  41. model (str): The model to use for creation.
  42. messages (Messages): The messages to process.
  43. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  44. executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
  45. **kwargs: Additional keyword arguments.
  46. Returns:
  47. str: The created result as a string.
  48. """
  49. loop = loop or asyncio.get_running_loop()
  50. def create_func() -> str:
  51. chunks = [str(chunk) for chunk in cls.create_completion(model, messages, False, **kwargs) if chunk]
  52. if chunks:
  53. return "".join(chunks)
  54. return await asyncio.wait_for(
  55. loop.run_in_executor(executor, create_func),
  56. timeout=kwargs.get("timeout")
  57. )
  58. @classmethod
  59. def get_parameters(cls) -> dict[str, Parameter]:
  60. return {name: parameter for name, parameter in signature(
  61. cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
  62. cls.create_async if issubclass(cls, AsyncProvider) else
  63. cls.create_completion
  64. ).parameters.items() if name not in ["kwargs", "model", "messages"]
  65. and (name != "stream" or cls.supports_stream)}
  66. @classmethod
  67. @property
  68. def params(cls) -> str:
  69. """
  70. Returns the parameters supported by the provider.
  71. Args:
  72. cls (type): The class on which this property is called.
  73. Returns:
  74. str: A string listing the supported parameters.
  75. """
  76. def get_type_name(annotation: type) -> str:
  77. return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)
  78. args = ""
  79. for name, param in cls.get_parameters().items():
  80. args += f"\n {name}"
  81. args += f": {get_type_name(param.annotation)}" if param.annotation is not Parameter.empty else ""
  82. default_value = f'"{param.default}"' if isinstance(param.default, str) else param.default
  83. args += f" = {default_value}" if param.default is not Parameter.empty else ""
  84. args += ","
  85. return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
  86. class AsyncProvider(AbstractProvider):
  87. """
  88. Provides asynchronous functionality for creating completions.
  89. """
  90. @classmethod
  91. def create_completion(
  92. cls,
  93. model: str,
  94. messages: Messages,
  95. stream: bool = False,
  96. **kwargs
  97. ) -> CreateResult:
  98. """
  99. Creates a completion result synchronously.
  100. Args:
  101. cls (type): The class on which this method is called.
  102. model (str): The model to use for creation.
  103. messages (Messages): The messages to process.
  104. stream (bool): Indicates whether to stream the results. Defaults to False.
  105. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  106. **kwargs: Additional keyword arguments.
  107. Returns:
  108. CreateResult: The result of the completion creation.
  109. """
  110. get_running_loop(check_nested=False)
  111. yield asyncio.run(cls.create_async(model, messages, **kwargs))
  112. @staticmethod
  113. @abstractmethod
  114. async def create_async(
  115. model: str,
  116. messages: Messages,
  117. **kwargs
  118. ) -> str:
  119. """
  120. Abstract method for creating asynchronous results.
  121. Args:
  122. model (str): The model to use for creation.
  123. messages (Messages): The messages to process.
  124. **kwargs: Additional keyword arguments.
  125. Raises:
  126. NotImplementedError: If this method is not overridden in derived classes.
  127. Returns:
  128. str: The created result as a string.
  129. """
  130. raise NotImplementedError()
  131. class AsyncGeneratorProvider(AsyncProvider):
  132. """
  133. Provides asynchronous generator functionality for streaming results.
  134. """
  135. supports_stream = True
  136. @classmethod
  137. def create_completion(
  138. cls,
  139. model: str,
  140. messages: Messages,
  141. stream: bool = True,
  142. **kwargs
  143. ) -> CreateResult:
  144. """
  145. Creates a streaming completion result synchronously.
  146. Args:
  147. cls (type): The class on which this method is called.
  148. model (str): The model to use for creation.
  149. messages (Messages): The messages to process.
  150. stream (bool): Indicates whether to stream the results. Defaults to True.
  151. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  152. **kwargs: Additional keyword arguments.
  153. Returns:
  154. CreateResult: The result of the streaming completion creation.
  155. """
  156. return to_sync_generator(
  157. cls.create_async_generator(model, messages, stream=stream, **kwargs)
  158. )
  159. @classmethod
  160. async def create_async(
  161. cls,
  162. model: str,
  163. messages: Messages,
  164. **kwargs
  165. ) -> str:
  166. """
  167. Asynchronously creates a result from a generator.
  168. Args:
  169. cls (type): The class on which this method is called.
  170. model (str): The model to use for creation.
  171. messages (Messages): The messages to process.
  172. **kwargs: Additional keyword arguments.
  173. Returns:
  174. str: The created result as a string.
  175. """
  176. return "".join([
  177. str(chunk) async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs)
  178. if chunk and not isinstance(chunk, (Exception, FinishReason, BaseConversation, SynthesizeData))
  179. ])
  180. @staticmethod
  181. @abstractmethod
  182. async def create_async_generator(
  183. model: str,
  184. messages: Messages,
  185. stream: bool = True,
  186. **kwargs
  187. ) -> AsyncResult:
  188. """
  189. Abstract method for creating an asynchronous generator.
  190. Args:
  191. model (str): The model to use for creation.
  192. messages (Messages): The messages to process.
  193. stream (bool): Indicates whether to stream the results. Defaults to True.
  194. **kwargs: Additional keyword arguments.
  195. Raises:
  196. NotImplementedError: If this method is not overridden in derived classes.
  197. Returns:
  198. AsyncResult: An asynchronous generator yielding results.
  199. """
  200. raise NotImplementedError()
  201. class ProviderModelMixin:
  202. default_model: str = None
  203. models: list[str] = []
  204. model_aliases: dict[str, str] = {}
  205. image_models: list = None
  206. last_model: str = None
  207. @classmethod
  208. def get_models(cls, **kwargs) -> list[str]:
  209. if not cls.models and cls.default_model is not None:
  210. return [cls.default_model]
  211. return cls.models
  212. @classmethod
  213. def get_model(cls, model: str, **kwargs) -> str:
  214. if not model and cls.default_model is not None:
  215. model = cls.default_model
  216. elif model in cls.model_aliases:
  217. model = cls.model_aliases[model]
  218. else:
  219. if model not in cls.get_models(**kwargs) and cls.models:
  220. raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
  221. cls.last_model = model
  222. debug.last_model = model
  223. return model