base_provider.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. from __future__ import annotations
  2. import asyncio
  3. from asyncio import AbstractEventLoop
  4. from concurrent.futures import ThreadPoolExecutor
  5. from abc import abstractmethod
  6. import json
  7. from inspect import signature, Parameter
  8. from typing import Optional, _GenericAlias
  9. from pathlib import Path
  10. try:
  11. from types import NoneType
  12. except ImportError:
  13. NoneType = type(None)
  14. from ..typing import CreateResult, AsyncResult, Messages
  15. from .types import BaseProvider
  16. from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
  17. from .response import BaseConversation, AuthResult
  18. from .helper import concat_chunks, async_concat_chunks
  19. from ..cookies import get_cookies_dir
  20. from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError
  21. from .. import debug
  22. SAFE_PARAMETERS = [
  23. "model", "messages", "stream", "timeout",
  24. "proxy", "images", "response_format",
  25. "prompt", "negative_prompt", "tools", "conversation",
  26. "history_disabled", "auto_continue",
  27. "temperature", "top_k", "top_p",
  28. "frequency_penalty", "presence_penalty",
  29. "max_tokens", "max_new_tokens", "stop",
  30. "api_key", "api_base", "seed", "width", "height",
  31. "proof_token", "max_retries", "web_search",
  32. "guidance_scale", "num_inference_steps", "randomize_seed",
  33. ]
  34. BASIC_PARAMETERS = {
  35. "provider": None,
  36. "model": "",
  37. "messages": [],
  38. "stream": False,
  39. "timeout": 0,
  40. "response_format": None,
  41. "max_tokens": None,
  42. "stop": None,
  43. }
  44. PARAMETER_EXAMPLES = {
  45. "proxy": "http://user:password@127.0.0.1:3128",
  46. "temperature": 1,
  47. "top_k": 1,
  48. "top_p": 1,
  49. "frequency_penalty": 1,
  50. "presence_penalty": 1,
  51. "messages": [{"role": "system", "content": ""}, {"role": "user", "content": ""}],
  52. "images": [["data:image/jpeg;base64,...", "filename.jpg"]],
  53. "response_format": {"type": "json_object"},
  54. "conversation": {"conversation_id": "550e8400-e29b-11d4-a716-...", "message_id": "550e8400-e29b-11d4-a716-..."},
  55. "max_new_tokens": 1024,
  56. "max_tokens": 4096,
  57. "seed": 42,
  58. }
  59. class AbstractProvider(BaseProvider):
  60. @classmethod
  61. @abstractmethod
  62. def create_completion(
  63. cls,
  64. model: str,
  65. messages: Messages,
  66. stream: bool,
  67. **kwargs
  68. ) -> CreateResult:
  69. """
  70. Create a completion with the given parameters.
  71. Args:
  72. model (str): The model to use.
  73. messages (Messages): The messages to process.
  74. stream (bool): Whether to use streaming.
  75. **kwargs: Additional keyword arguments.
  76. Returns:
  77. CreateResult: The result of the creation process.
  78. """
  79. raise NotImplementedError()
  80. @classmethod
  81. async def create_async(
  82. cls,
  83. model: str,
  84. messages: Messages,
  85. *,
  86. timeout: int = None,
  87. loop: AbstractEventLoop = None,
  88. executor: ThreadPoolExecutor = None,
  89. **kwargs
  90. ) -> str:
  91. """
  92. Asynchronously creates a result based on the given model and messages.
  93. Args:
  94. cls (type): The class on which this method is called.
  95. model (str): The model to use for creation.
  96. messages (Messages): The messages to process.
  97. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  98. executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
  99. **kwargs: Additional keyword arguments.
  100. Returns:
  101. str: The created result as a string.
  102. """
  103. loop = asyncio.get_running_loop() if loop is None else loop
  104. def create_func() -> str:
  105. return concat_chunks(cls.create_completion(model, messages, **kwargs))
  106. return await asyncio.wait_for(
  107. loop.run_in_executor(executor, create_func),
  108. timeout=timeout
  109. )
  110. @classmethod
  111. def get_create_function(cls) -> callable:
  112. return cls.create_completion
  113. @classmethod
  114. def get_async_create_function(cls) -> callable:
  115. return cls.create_async
  116. @classmethod
  117. def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
  118. params = {name: parameter for name, parameter in signature(
  119. cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
  120. cls.create_async if issubclass(cls, AsyncProvider) else
  121. cls.create_completion
  122. ).parameters.items() if name in SAFE_PARAMETERS
  123. and (name != "stream" or cls.supports_stream)}
  124. if as_json:
  125. def get_type_as_var(annotation: type, key: str, default):
  126. if key in PARAMETER_EXAMPLES:
  127. if key == "messages" and not cls.supports_system_message:
  128. return [PARAMETER_EXAMPLES[key][-1]]
  129. return PARAMETER_EXAMPLES[key]
  130. if isinstance(annotation, type):
  131. if issubclass(annotation, int):
  132. return 0
  133. elif issubclass(annotation, float):
  134. return 0.0
  135. elif issubclass(annotation, bool):
  136. return False
  137. elif issubclass(annotation, str):
  138. return ""
  139. elif issubclass(annotation, dict):
  140. return {}
  141. elif issubclass(annotation, list):
  142. return []
  143. elif issubclass(annotation, BaseConversation):
  144. return {}
  145. elif issubclass(annotation, NoneType):
  146. return {}
  147. elif annotation is None:
  148. return None
  149. elif annotation == "str" or annotation == "list[str]":
  150. return default
  151. elif isinstance(annotation, _GenericAlias):
  152. if annotation.__origin__ is Optional:
  153. return get_type_as_var(annotation.__args__[0])
  154. else:
  155. return str(annotation)
  156. return { name: (
  157. param.default
  158. if isinstance(param, Parameter) and param.default is not Parameter.empty and param.default is not None
  159. else get_type_as_var(param.annotation, name, param.default) if isinstance(param, Parameter) else param
  160. ) for name, param in {
  161. **BASIC_PARAMETERS,
  162. **params,
  163. **{"provider": cls.__name__, "model": getattr(cls, "default_model", ""), "stream": cls.supports_stream},
  164. }.items()}
  165. return params
  166. @classmethod
  167. @property
  168. def params(cls) -> str:
  169. """
  170. Returns the parameters supported by the provider.
  171. Args:
  172. cls (type): The class on which this property is called.
  173. Returns:
  174. str: A string listing the supported parameters.
  175. """
  176. def get_type_name(annotation: type) -> str:
  177. return getattr(annotation, "__name__", str(annotation)) if annotation is not Parameter.empty else ""
  178. args = ""
  179. for name, param in cls.get_parameters().items():
  180. args += f"\n {name}"
  181. args += f": {get_type_name(param.annotation)}"
  182. default_value = getattr(cls, "default_model", "") if name == "model" else param.default
  183. default_value = f'"{default_value}"' if isinstance(default_value, str) else default_value
  184. args += f" = {default_value}" if param.default is not Parameter.empty else ""
  185. args += ","
  186. return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
  187. class AsyncProvider(AbstractProvider):
  188. """
  189. Provides asynchronous functionality for creating completions.
  190. """
  191. @classmethod
  192. def create_completion(
  193. cls,
  194. model: str,
  195. messages: Messages,
  196. stream: bool = False,
  197. **kwargs
  198. ) -> CreateResult:
  199. """
  200. Creates a completion result synchronously.
  201. Args:
  202. cls (type): The class on which this method is called.
  203. model (str): The model to use for creation.
  204. messages (Messages): The messages to process.
  205. stream (bool): Indicates whether to stream the results. Defaults to False.
  206. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  207. **kwargs: Additional keyword arguments.
  208. Returns:
  209. CreateResult: The result of the completion creation.
  210. """
  211. get_running_loop(check_nested=False)
  212. yield asyncio.run(cls.create_async(model, messages, **kwargs))
  213. @staticmethod
  214. @abstractmethod
  215. async def create_async(
  216. model: str,
  217. messages: Messages,
  218. **kwargs
  219. ) -> str:
  220. """
  221. Abstract method for creating asynchronous results.
  222. Args:
  223. model (str): The model to use for creation.
  224. messages (Messages): The messages to process.
  225. **kwargs: Additional keyword arguments.
  226. Raises:
  227. NotImplementedError: If this method is not overridden in derived classes.
  228. Returns:
  229. str: The created result as a string.
  230. """
  231. raise NotImplementedError()
  232. @classmethod
  233. def get_create_function(cls) -> callable:
  234. return cls.create_completion
  235. @classmethod
  236. def get_async_create_function(cls) -> callable:
  237. return cls.create_async
  238. class AsyncGeneratorProvider(AbstractProvider):
  239. """
  240. Provides asynchronous generator functionality for streaming results.
  241. """
  242. supports_stream = True
  243. @classmethod
  244. def create_completion(
  245. cls,
  246. model: str,
  247. messages: Messages,
  248. stream: bool = True,
  249. **kwargs
  250. ) -> CreateResult:
  251. """
  252. Creates a streaming completion result synchronously.
  253. Args:
  254. cls (type): The class on which this method is called.
  255. model (str): The model to use for creation.
  256. messages (Messages): The messages to process.
  257. stream (bool): Indicates whether to stream the results. Defaults to True.
  258. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  259. **kwargs: Additional keyword arguments.
  260. Returns:
  261. CreateResult: The result of the streaming completion creation.
  262. """
  263. return to_sync_generator(
  264. cls.create_async_generator(model, messages, stream=stream, **kwargs),
  265. stream=stream
  266. )
  267. @staticmethod
  268. @abstractmethod
  269. async def create_async_generator(
  270. model: str,
  271. messages: Messages,
  272. stream: bool = True,
  273. **kwargs
  274. ) -> AsyncResult:
  275. """
  276. Abstract method for creating an asynchronous generator.
  277. Args:
  278. model (str): The model to use for creation.
  279. messages (Messages): The messages to process.
  280. stream (bool): Indicates whether to stream the results. Defaults to True.
  281. **kwargs: Additional keyword arguments.
  282. Raises:
  283. NotImplementedError: If this method is not overridden in derived classes.
  284. Returns:
  285. AsyncResult: An asynchronous generator yielding results.
  286. """
  287. raise NotImplementedError()
  288. @classmethod
  289. def get_create_function(cls) -> callable:
  290. return cls.create_completion
  291. @classmethod
  292. def get_async_create_function(cls) -> callable:
  293. return cls.create_async_generator
  294. class ProviderModelMixin:
  295. default_model: str = None
  296. models: list[str] = []
  297. model_aliases: dict[str, str] = {}
  298. image_models: list = None
  299. last_model: str = None
  300. @classmethod
  301. def get_models(cls, **kwargs) -> list[str]:
  302. if not cls.models and cls.default_model is not None:
  303. return [cls.default_model]
  304. return cls.models
  305. @classmethod
  306. def get_model(cls, model: str, **kwargs) -> str:
  307. if not model and cls.default_model is not None:
  308. model = cls.default_model
  309. elif model in cls.model_aliases:
  310. model = cls.model_aliases[model]
  311. else:
  312. if model not in cls.get_models(**kwargs) and cls.models:
  313. raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
  314. cls.last_model = model
  315. debug.last_model = model
  316. return model
  317. class RaiseErrorMixin():
  318. @staticmethod
  319. def raise_error(data: dict):
  320. if "error_message" in data:
  321. raise ResponseError(data["error_message"])
  322. elif "error" in data:
  323. if "code" in data["error"]:
  324. raise ResponseError(f'Error {data["error"]["code"]}: {data["error"]["message"]}')
  325. elif "message" in data["error"]:
  326. raise ResponseError(data["error"]["message"])
  327. else:
  328. raise ResponseError(data["error"])
  329. class AsyncAuthedProvider(AsyncGeneratorProvider):
  330. @classmethod
  331. async def on_auth_async(cls, **kwargs) -> AuthResult:
  332. if "api_key" not in kwargs:
  333. raise MissingAuthError(f"API key is required for {cls.__name__}")
  334. return AuthResult()
  335. @classmethod
  336. def on_auth(cls, **kwargs) -> AuthResult:
  337. auth_result = cls.on_auth_async(**kwargs)
  338. if hasattr(auth_result, "__aiter__"):
  339. return to_sync_generator(auth_result)
  340. return asyncio.run(auth_result)
  341. @classmethod
  342. def get_create_function(cls) -> callable:
  343. return cls.create_completion
  344. @classmethod
  345. def get_async_create_function(cls) -> callable:
  346. return cls.create_async_generator
  347. @classmethod
  348. def get_cache_file(cls) -> Path:
  349. return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
  350. @classmethod
  351. def create_completion(
  352. cls,
  353. model: str,
  354. messages: Messages,
  355. **kwargs
  356. ) -> CreateResult:
  357. try:
  358. auth_result = AuthResult()
  359. cache_file = cls.get_cache_file()
  360. if cache_file.exists():
  361. with cache_file.open("r") as f:
  362. auth_result = AuthResult(**json.load(f))
  363. else:
  364. auth_result = cls.on_auth(**kwargs)
  365. try:
  366. for chunk in auth_result:
  367. if hasattr(chunk, "get_dict"):
  368. auth_result = chunk
  369. else:
  370. yield chunk
  371. except TypeError:
  372. pass
  373. yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
  374. except (MissingAuthError, NoValidHarFileError):
  375. auth_result = cls.on_auth(**kwargs)
  376. try:
  377. for chunk in auth_result:
  378. if hasattr(chunk, "get_dict"):
  379. auth_result = chunk
  380. else:
  381. yield chunk
  382. except TypeError:
  383. pass
  384. yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
  385. finally:
  386. if hasattr(auth_result, "get_dict"):
  387. data = auth_result.get_dict()
  388. cache_file.parent.mkdir(parents=True, exist_ok=True)
  389. cache_file.write_text(json.dumps(data))
  390. elif cache_file.exists():
  391. cache_file.unlink()
  392. @classmethod
  393. async def create_async_generator(
  394. cls,
  395. model: str,
  396. messages: Messages,
  397. **kwargs
  398. ) -> AsyncResult:
  399. try:
  400. auth_result = AuthResult()
  401. cache_file = Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
  402. if cache_file.exists():
  403. with cache_file.open("r") as f:
  404. auth_result = AuthResult(**json.load(f))
  405. else:
  406. auth_result = cls.on_auth_async(**kwargs)
  407. if hasattr(auth_result, "_aiter__"):
  408. async for chunk in auth_result:
  409. if isinstance(chunk, AsyncResult):
  410. auth_result = chunk
  411. else:
  412. yield chunk
  413. else:
  414. auth_result = await auth_result
  415. response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
  416. async for chunk in response:
  417. yield chunk
  418. except (MissingAuthError, NoValidHarFileError):
  419. if cache_file.exists():
  420. cache_file.unlink()
  421. auth_result = cls.on_auth_async(**kwargs)
  422. if hasattr(auth_result, "_aiter__"):
  423. async for chunk in auth_result:
  424. if isinstance(chunk, AsyncResult):
  425. auth_result = chunk
  426. else:
  427. yield chunk
  428. else:
  429. auth_result = await auth_result
  430. response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
  431. async for chunk in response:
  432. yield chunk
  433. finally:
  434. if hasattr(auth_result, "get_dict"):
  435. cache_file.parent.mkdir(parents=True, exist_ok=True)
  436. cache_file.write_text(json.dumps(auth_result.get_dict()))
  437. elif cache_file.exists():
  438. cache_file.unlink()