base_provider.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. from __future__ import annotations
  2. import asyncio
  3. from asyncio import AbstractEventLoop
  4. from concurrent.futures import ThreadPoolExecutor
  5. from abc import abstractmethod
  6. import json
  7. from inspect import signature, Parameter
  8. from typing import Optional, _GenericAlias
  9. from pathlib import Path
  10. try:
  11. from types import NoneType
  12. except ImportError:
  13. NoneType = type(None)
  14. from ..typing import CreateResult, AsyncResult, Messages
  15. from .types import BaseProvider
  16. from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
  17. from .response import BaseConversation, AuthResult
  18. from .helper import concat_chunks, async_concat_chunks
  19. from ..cookies import get_cookies_dir
  20. from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError
  21. from .. import debug
  # Names of create-function parameters that get_parameters() is allowed to
  # expose; any signature parameter not listed here is filtered out.
  SAFE_PARAMETERS = [
      "model", "messages", "stream", "timeout",
      "proxy", "images", "response_format",
      "prompt", "negative_prompt", "tools", "conversation",
      "history_disabled", "auto_continue",
      "temperature", "top_k", "top_p",
      "frequency_penalty", "presence_penalty",
      "max_tokens", "max_new_tokens", "stop",
      "api_key", "api_base", "seed", "width", "height",
      "proof_token", "max_retries", "web_search",
      "guidance_scale", "num_inference_steps", "randomize_seed",
      "safe", "enhance", "private",
  ]

  # Baseline entries that are always present in the JSON form returned by
  # get_parameters(as_json=True); provider-specific values are merged on top.
  BASIC_PARAMETERS = {
      "provider": None,
      "model": "",
      "messages": [],
      "stream": False,
      "timeout": 0,
      "response_format": None,
      "max_tokens": None,
      "stop": None,
  }

  # Example values used in the JSON form for parameters that have no usable
  # default of their own.
  PARAMETER_EXAMPLES = {
      "proxy": "http://user:password@127.0.0.1:3128",
      "temperature": 1,
      "top_k": 1,
      "top_p": 1,
      "frequency_penalty": 1,
      "presence_penalty": 1,
      "messages": [{"role": "system", "content": ""}, {"role": "user", "content": ""}],
      "images": [["data:image/jpeg;base64,...", "filename.jpg"]],
      "response_format": {"type": "json_object"},
      "conversation": {"conversation_id": "550e8400-e29b-11d4-a716-...", "message_id": "550e8400-e29b-11d4-a716-..."},
      "max_new_tokens": 1024,
      "max_tokens": 4096,
      "seed": 42,
      "stop": ["stop1", "stop2"],
      "tools": [],
  }
  class _ClassProperty:
      """Minimal read-only, class-level property descriptor.

      Stacking ``@classmethod`` on ``@property`` was deprecated in Python 3.11
      and no longer works from Python 3.13 on, so this descriptor provides the
      same ``Class.attr`` access without relying on that chaining.
      """

      def __init__(self, fget):
          self.fget = fget
          self.__doc__ = fget.__doc__

      def __get__(self, instance, owner=None):
          # Serve both ``Class.attr`` and ``instance.attr`` lookups.
          return self.fget(owner if owner is not None else type(instance))


  class AbstractProvider(BaseProvider):
      """Base class for providers that implement a synchronous create_completion."""

      @classmethod
      @abstractmethod
      def create_completion(
          cls,
          model: str,
          messages: Messages,
          stream: bool,
          **kwargs
      ) -> CreateResult:
          """
          Create a completion with the given parameters.

          Args:
              model (str): The model to use.
              messages (Messages): The messages to process.
              stream (bool): Whether to use streaming.
              **kwargs: Additional keyword arguments.

          Returns:
              CreateResult: The result of the creation process.
          """
          raise NotImplementedError()

      @classmethod
      async def create_async(
          cls,
          model: str,
          messages: Messages,
          *,
          timeout: int = None,
          loop: AbstractEventLoop = None,
          executor: ThreadPoolExecutor = None,
          **kwargs
      ) -> str:
          """
          Asynchronously creates a result based on the given model and messages.

          The synchronous create_completion is run in an executor and its
          chunks are concatenated with concat_chunks.

          Args:
              cls (type): The class on which this method is called.
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              timeout (int, optional): Passed to asyncio.wait_for. Defaults to None.
              loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
              executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
              **kwargs: Additional keyword arguments forwarded to create_completion.

          Returns:
              str: The created result as a string.
          """
          loop = asyncio.get_running_loop() if loop is None else loop

          def create_func() -> str:
              return concat_chunks(cls.create_completion(model, messages, **kwargs))

          return await asyncio.wait_for(
              loop.run_in_executor(executor, create_func),
              timeout=timeout
          )

      @classmethod
      def get_create_function(cls) -> callable:
          """Return the synchronous create function (create_completion)."""
          return cls.create_completion

      @classmethod
      def get_async_create_function(cls) -> callable:
          """Return the asynchronous create function (create_async)."""
          return cls.create_async

      @classmethod
      def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
          """
          Collect the supported parameters of the provider's create function.

          The inspected function is create_async_generator for
          AsyncGeneratorProvider subclasses, create_async for AsyncProvider
          subclasses and create_completion otherwise. Only names listed in
          SAFE_PARAMETERS are kept, and "stream" is dropped when the provider
          does not support streaming.

          Args:
              as_json (bool): When True, return plain example/default values
                  merged over BASIC_PARAMETERS instead of Parameter objects.

          Returns:
              dict: Parameter objects by name, or example values by name when
              as_json is True.
          """
          if issubclass(cls, AsyncGeneratorProvider):
              create_function = cls.create_async_generator
          elif issubclass(cls, AsyncProvider):
              create_function = cls.create_async
          else:
              create_function = cls.create_completion
          params = {
              name: parameter
              for name, parameter in signature(create_function).parameters.items()
              if name in SAFE_PARAMETERS and (name != "stream" or cls.supports_stream)
          }
          if not as_json:
              return params

          # Local import: only needed to recognise Optional[...] annotations.
          from typing import Union

          def get_type_as_var(annotation: type, key: str, default):
              """Pick a representative value for one parameter."""
              if key in PARAMETER_EXAMPLES:
                  if key == "messages" and not cls.supports_system_message:
                      # Keep only the last example message (the user message).
                      return [PARAMETER_EXAMPLES[key][-1]]
                  return PARAMETER_EXAMPLES[key]
              if isinstance(annotation, type):
                  # bool must be tested before int because bool subclasses int.
                  if issubclass(annotation, bool):
                      return False
                  if issubclass(annotation, int):
                      return 0
                  if issubclass(annotation, float):
                      return 0.0
                  if issubclass(annotation, str):
                      return ""
                  if issubclass(annotation, dict):
                      return {}
                  if issubclass(annotation, list):
                      return []
                  if issubclass(annotation, BaseConversation):
                      return {}
                  if issubclass(annotation, NoneType):
                      return {}
                  # Any other class (including Parameter.empty) has no example.
                  return None
              if annotation is None:
                  return None
              if annotation == "str" or annotation == "list[str]":
                  if default is Parameter.empty:
                      # Do not leak the inspect sentinel into the JSON form.
                      return "" if annotation == "str" else []
                  return default
              if isinstance(annotation, _GenericAlias):
                  # Optional[X] is Union[X, None]; its __origin__ is Union.
                  if annotation.__origin__ is Union and NoneType in annotation.__args__:
                      inner = [arg for arg in annotation.__args__ if arg is not NoneType]
                      if len(inner) == 1:
                          return get_type_as_var(inner[0], key, default)
              return str(annotation)

          merged = {
              **BASIC_PARAMETERS,
              **params,
              "provider": cls.__name__,
              "model": getattr(cls, "default_model", ""),
              "stream": cls.supports_stream,
          }
          result = {}
          for name, param in merged.items():
              if not isinstance(param, Parameter):
                  # Plain values from BASIC_PARAMETERS or the overrides above.
                  result[name] = param
              elif param.default is not Parameter.empty and param.default is not None:
                  result[name] = param.default
              else:
                  result[name] = get_type_as_var(param.annotation, name, param.default)
          return result

      @_ClassProperty
      def params(cls) -> str:
          """
          Returns the parameters supported by the provider.

          Args:
              cls (type): The class on which this property is called.

          Returns:
              str: A string listing the supported parameters.
          """
          def get_type_name(annotation: type) -> str:
              return getattr(annotation, "__name__", str(annotation)) if annotation is not Parameter.empty else ""

          entries = []
          for name, param in cls.get_parameters().items():
              entry = f"\n {name}: {get_type_name(param.annotation)}"
              if param.default is not Parameter.empty:
                  # The model default shown is the provider's default_model.
                  default_value = getattr(cls, "default_model", "") if name == "model" else param.default
                  if isinstance(default_value, str):
                      default_value = f'"{default_value}"'
                  entry += f" = {default_value}"
              entries.append(entry + ",")
          args = "".join(entries)
          return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
  class AsyncProvider(AbstractProvider):
      """
      Provides asynchronous functionality for creating completions.
      """

      @classmethod
      def create_completion(
          cls,
          model: str,
          messages: Messages,
          stream: bool = False,
          **kwargs
      ) -> CreateResult:
          """
          Creates a completion result synchronously.

          Runs create_async to completion with asyncio.run and yields its
          result as a single chunk.

          Args:
              cls (type): The class on which this method is called.
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              stream (bool): Accepted for interface compatibility; not used here. Defaults to False.
              **kwargs: Additional keyword arguments forwarded to create_async.

          Returns:
              CreateResult: A generator yielding the single completion result.
          """
          # NOTE(review): presumably guards against being called while an event
          # loop is already running - confirm against the helper in .asyncio.
          get_running_loop(check_nested=False)
          yield asyncio.run(cls.create_async(model, messages, **kwargs))

      @staticmethod
      @abstractmethod
      async def create_async(
          model: str,
          messages: Messages,
          **kwargs
      ) -> str:
          """
          Abstract method for creating asynchronous results.

          Args:
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              **kwargs: Additional keyword arguments.

          Raises:
              NotImplementedError: If this method is not overridden in derived classes.

          Returns:
              str: The created result as a string.
          """
          raise NotImplementedError()

      @classmethod
      def get_create_function(cls) -> callable:
          """Return the synchronous create function (create_completion)."""
          return cls.create_completion

      @classmethod
      def get_async_create_function(cls) -> callable:
          """Return the asynchronous create function (create_async)."""
          return cls.create_async
  class AsyncGeneratorProvider(AbstractProvider):
      """
      Provides asynchronous generator functionality for streaming results.
      """
      supports_stream = True

      @classmethod
      def create_completion(
          cls,
          model: str,
          messages: Messages,
          stream: bool = True,
          **kwargs
      ) -> CreateResult:
          """
          Creates a streaming completion result synchronously.

          Wraps the async generator returned by create_async_generator with
          to_sync_generator.

          Args:
              cls (type): The class on which this method is called.
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              stream (bool): Indicates whether to stream the results. Defaults to True.
              **kwargs: Additional keyword arguments forwarded to create_async_generator.

          Returns:
              CreateResult: The result of the streaming completion creation.
          """
          return to_sync_generator(
              cls.create_async_generator(model, messages, stream=stream, **kwargs),
              stream=stream
          )

      @staticmethod
      @abstractmethod
      async def create_async_generator(
          model: str,
          messages: Messages,
          stream: bool = True,
          **kwargs
      ) -> AsyncResult:
          """
          Abstract method for creating an asynchronous generator.

          Args:
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              stream (bool): Indicates whether to stream the results. Defaults to True.
              **kwargs: Additional keyword arguments.

          Raises:
              NotImplementedError: If this method is not overridden in derived classes.

          Returns:
              AsyncResult: An asynchronous generator yielding results.
          """
          raise NotImplementedError()

      @classmethod
      def get_create_function(cls) -> callable:
          """Return the synchronous create function (create_completion)."""
          return cls.create_completion

      @classmethod
      def get_async_create_function(cls) -> callable:
          """Return the asynchronous create function (create_async_generator)."""
          return cls.create_async_generator
  class ProviderModelMixin:
      """Mixin that adds model listing, alias resolution and validation."""

      # Model used when the caller passes an empty model name.
      default_model: str = None
      # Supported model names; when empty, any model name is accepted.
      models: list[str] = []
      # Maps alternative model names to the names actually used.
      model_aliases: dict[str, str] = {}
      image_models: list = []
      vision_models: list = []
      # Last model resolved by get_model (also mirrored to debug.last_model).
      last_model: str = None

      @classmethod
      def get_models(cls, **kwargs) -> list[str]:
          """
          Return the supported models.

          Returns:
              list[str]: cls.models, or a one-element list holding
              default_model when no models are listed and a default is set.
          """
          if not cls.models and cls.default_model is not None:
              return [cls.default_model]
          return cls.models

      @classmethod
      def get_model(cls, model: str, **kwargs) -> str:
          """
          Resolve a requested model name to the name to use.

          An empty name resolves to default_model (when set) and aliases are
          translated through model_aliases. Any other name is checked against
          get_models() when the provider lists models.

          Args:
              model (str): The requested model name; may be empty.
              **kwargs: Forwarded to get_models.

          Raises:
              ModelNotSupportedError: If the name is not in the listed models.

          Returns:
              str: The resolved model name.
          """
          if not model and cls.default_model is not None:
              model = cls.default_model
          elif model in cls.model_aliases:
              model = cls.model_aliases[model]
          else:
              # get_models() is evaluated before cls.models is read, so an
              # override that fills cls.models as a side effect still works.
              if model not in cls.get_models(**kwargs) and cls.models:
                  raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
          cls.last_model = model
          debug.last_model = model
          return model
  class RaiseErrorMixin():
      """Mixin that turns error payloads of a response dict into ResponseError."""

      @staticmethod
      def raise_error(data: dict):
          """
          Raise ResponseError when the response data describes an error.

          Args:
              data (dict): The decoded response.

          Raises:
              ResponseError: If data has an "error_message" or "error" entry,
                  or if "choices" is missing or empty.
          """
          if "error_message" in data:
              raise ResponseError(data["error_message"])
          if "error" in data:
              error = data["error"]
              # Only index into the payload when it is a mapping; for a plain
              # string, `"code" in error` would be a substring test.
              if isinstance(error, dict):
                  if "code" in error and "message" in error:
                      raise ResponseError(f'Error {error["code"]}: {error["message"]}')
                  if "message" in error:
                      raise ResponseError(error["message"])
              raise ResponseError(error)
          if "choices" not in data or not data["choices"]:
              raise ResponseError(f"Invalid response: {json.dumps(data)}")
  class AsyncAuthedProvider(AsyncGeneratorProvider):
      """
      Async generator provider that authenticates before creating completions.

      The auth result is cached as JSON in the cookies directory and reused
      on later calls. Subclasses are expected to provide create_authed, which
      is called here but not defined in this class.
      """

      @classmethod
      async def on_auth_async(cls, **kwargs) -> AuthResult:
          """
          Authenticate asynchronously.

          Raises:
              MissingAuthError: If no "api_key" keyword argument is given.

          Returns:
              AuthResult: An empty auth result.
          """
          if "api_key" not in kwargs:
              raise MissingAuthError(f"API key is required for {cls.__name__}")
          return AuthResult()

      @classmethod
      def on_auth(cls, **kwargs) -> AuthResult:
          """
          Authenticate synchronously by driving on_auth_async.

          Returns:
              A sync generator when on_auth_async is an async generator,
              otherwise the value returned by running the coroutine.
          """
          auth_result = cls.on_auth_async(**kwargs)
          if hasattr(auth_result, "__aiter__"):
              return to_sync_generator(auth_result)
          return asyncio.run(auth_result)

      @classmethod
      def get_create_function(cls) -> callable:
          """Return the synchronous create function (create_completion)."""
          return cls.create_completion

      @classmethod
      def get_async_create_function(cls) -> callable:
          """Return the asynchronous create function (create_async_generator)."""
          return cls.create_async_generator

      @classmethod
      def get_cache_file(cls) -> Path:
          """Return the auth cache path, named after cls.parent if set, else the class name."""
          return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"

      @classmethod
      def _authenticate(cls, **kwargs):
          """
          Run on_auth, yield its non-auth chunks and return the auth result.

          Handles both shapes on_auth can return: an already resolved auth
          result (on_auth_async is a coroutine) or a generator of chunks.
          Returns None when the generator yields no auth result.
          """
          outcome = cls.on_auth(**kwargs)
          if hasattr(outcome, "get_dict"):
              # on_auth already resolved the coroutine; nothing to iterate.
              return outcome
          auth_result = None
          for chunk in outcome:
              if hasattr(chunk, "get_dict"):
                  auth_result = chunk
              else:
                  yield chunk
          return auth_result

      @classmethod
      def create_completion(
          cls,
          model: str,
          messages: Messages,
          **kwargs
      ) -> CreateResult:
          """
          Create a completion synchronously, authenticating when needed.

          Uses the cached auth result when the cache file exists, otherwise
          authenticates first. On MissingAuthError or NoValidHarFileError it
          authenticates again and retries once. Finally the auth result is
          written to the cache, or the cache file is removed when no auth
          result is available.

          Args:
              model (str): The model to use.
              messages (Messages): The messages to process.
              **kwargs: Forwarded to on_auth and create_authed.

          Returns:
              CreateResult: Chunks from authentication and from create_authed.
          """
          # Resolved before the try so the finally block can always use it.
          cache_file = cls.get_cache_file()
          auth_result = AuthResult()
          try:
              if cache_file.exists():
                  with cache_file.open("r") as f:
                      auth_result = AuthResult(**json.load(f))
              else:
                  # Cleared first so a failed login is not persisted below.
                  auth_result = None
                  auth_result = yield from cls._authenticate(**kwargs)
              yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
          except (MissingAuthError, NoValidHarFileError):
              # Drop the rejected credentials before authenticating again.
              auth_result = None
              auth_result = yield from cls._authenticate(**kwargs)
              yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
          finally:
              if hasattr(auth_result, "get_dict"):
                  data = auth_result.get_dict()
                  cache_file.parent.mkdir(parents=True, exist_ok=True)
                  cache_file.write_text(json.dumps(data))
              elif cache_file.exists():
                  cache_file.unlink()

      @classmethod
      async def create_async_generator(
          cls,
          model: str,
          messages: Messages,
          **kwargs
      ) -> AsyncResult:
          """
          Create a completion asynchronously, authenticating when needed.

          Mirrors create_completion: cached auth is reused, a
          MissingAuthError or NoValidHarFileError removes the cache file and
          triggers one re-authentication, and the final auth result is
          written back to the cache (or the cache file is removed).

          Args:
              model (str): The model to use.
              messages (Messages): The messages to process.
              **kwargs: Forwarded to on_auth_async and create_authed.

          Returns:
              AsyncResult: Chunks from authentication and from create_authed.
          """
          # Resolved before the try so the finally block can always use it.
          cache_file = cls.get_cache_file()
          auth_result = AuthResult()

          async def authenticate():
              """Yield non-auth chunks from on_auth_async and store the auth result."""
              nonlocal auth_result
              # Cleared first so a failed login is not persisted in finally.
              auth_result = None
              pending = cls.on_auth_async(**kwargs)
              if hasattr(pending, "__aiter__"):
                  async for chunk in pending:
                      if hasattr(chunk, "get_dict"):
                          auth_result = chunk
                      else:
                          yield chunk
              else:
                  # on_auth_async is a plain coroutine returning the result.
                  auth_result = await pending

          try:
              if cache_file.exists():
                  with cache_file.open("r") as f:
                      auth_result = AuthResult(**json.load(f))
              else:
                  async for chunk in authenticate():
                      yield chunk
              response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
              async for chunk in response:
                  yield chunk
          except (MissingAuthError, NoValidHarFileError):
              if cache_file.exists():
                  cache_file.unlink()
              async for chunk in authenticate():
                  yield chunk
              response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
              async for chunk in response:
                  yield chunk
          finally:
              if hasattr(auth_result, "get_dict"):
                  cache_file.parent.mkdir(parents=True, exist_ok=True)
                  cache_file.write_text(json.dumps(auth_result.get_dict()))
              elif cache_file.exists():
                  cache_file.unlink()