base_provider.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. from __future__ import annotations
  2. import sys
  3. import asyncio
  4. from asyncio import AbstractEventLoop
  5. from concurrent.futures import ThreadPoolExecutor
  6. from abc import abstractmethod
  7. from inspect import signature, Parameter
  8. from .helper import get_event_loop, get_cookies, format_prompt
  9. from ..typing import CreateResult, AsyncResult, Messages, Union
  10. from ..base_provider import BaseProvider
  11. if sys.version_info < (3, 10):
  12. NoneType = type(None)
  13. else:
  14. from types import NoneType
  15. # Change event loop policy on windows for curl_cffi
  16. if sys.platform == 'win32':
  17. if isinstance(
  18. asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
  19. ):
  20. asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
  21. class AbstractProvider(BaseProvider):
  22. @classmethod
  23. async def create_async(
  24. cls,
  25. model: str,
  26. messages: Messages,
  27. *,
  28. loop: AbstractEventLoop = None,
  29. executor: ThreadPoolExecutor = None,
  30. **kwargs
  31. ) -> str:
  32. if not loop:
  33. loop = get_event_loop()
  34. def create_func() -> str:
  35. return "".join(cls.create_completion(
  36. model,
  37. messages,
  38. False,
  39. **kwargs
  40. ))
  41. return await asyncio.wait_for(
  42. loop.run_in_executor(
  43. executor,
  44. create_func
  45. ),
  46. timeout=kwargs.get("timeout", 0)
  47. )
  48. @classmethod
  49. @property
  50. def params(cls) -> str:
  51. if issubclass(cls, AsyncGeneratorProvider):
  52. sig = signature(cls.create_async_generator)
  53. elif issubclass(cls, AsyncProvider):
  54. sig = signature(cls.create_async)
  55. else:
  56. sig = signature(cls.create_completion)
  57. def get_type_name(annotation: type) -> str:
  58. if hasattr(annotation, "__name__"):
  59. annotation = annotation.__name__
  60. elif isinstance(annotation, NoneType):
  61. annotation = "None"
  62. return str(annotation)
  63. args = ""
  64. for name, param in sig.parameters.items():
  65. if name in ("self", "kwargs"):
  66. continue
  67. if name == "stream" and not cls.supports_stream:
  68. continue
  69. if args:
  70. args += ", "
  71. args += "\n"
  72. args += " " + name
  73. if name != "model" and param.annotation is not Parameter.empty:
  74. args += f": {get_type_name(param.annotation)}"
  75. if param.default == "":
  76. args += ' = ""'
  77. elif param.default is not Parameter.empty:
  78. args += f" = {param.default}"
  79. return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
  80. class AsyncProvider(AbstractProvider):
  81. @classmethod
  82. def create_completion(
  83. cls,
  84. model: str,
  85. messages: Messages,
  86. stream: bool = False,
  87. *,
  88. loop: AbstractEventLoop = None,
  89. **kwargs
  90. ) -> CreateResult:
  91. if not loop:
  92. loop = get_event_loop()
  93. coro = cls.create_async(model, messages, **kwargs)
  94. yield loop.run_until_complete(coro)
  95. @staticmethod
  96. @abstractmethod
  97. async def create_async(
  98. model: str,
  99. messages: Messages,
  100. **kwargs
  101. ) -> str:
  102. raise NotImplementedError()
  103. class AsyncGeneratorProvider(AsyncProvider):
  104. supports_stream = True
  105. @classmethod
  106. def create_completion(
  107. cls,
  108. model: str,
  109. messages: Messages,
  110. stream: bool = True,
  111. *,
  112. loop: AbstractEventLoop = None,
  113. **kwargs
  114. ) -> CreateResult:
  115. if not loop:
  116. loop = get_event_loop()
  117. generator = cls.create_async_generator(
  118. model,
  119. messages,
  120. stream=stream,
  121. **kwargs
  122. )
  123. gen = generator.__aiter__()
  124. while True:
  125. try:
  126. yield loop.run_until_complete(gen.__anext__())
  127. except StopAsyncIteration:
  128. break
  129. @classmethod
  130. async def create_async(
  131. cls,
  132. model: str,
  133. messages: Messages,
  134. **kwargs
  135. ) -> str:
  136. return "".join([
  137. chunk async for chunk in cls.create_async_generator(
  138. model,
  139. messages,
  140. stream=False,
  141. **kwargs
  142. )
  143. ])
  144. @staticmethod
  145. @abstractmethod
  146. def create_async_generator(
  147. model: str,
  148. messages: Messages,
  149. stream: bool = True,
  150. **kwargs
  151. ) -> AsyncResult:
  152. raise NotImplementedError()