api.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203
  1. from __future__ import annotations
  2. import logging
  3. import os
  4. import asyncio
  5. from typing import Iterator
  6. from flask import send_from_directory
  7. from inspect import signature
  8. from g4f import version, models
  9. from g4f import get_last_provider, ChatCompletion
  10. from g4f.errors import VersionNotFoundError
  11. from g4f.image import ImagePreview, ImageResponse, copy_images, ensure_images_dir, images_dir
  12. from g4f.Provider import ProviderType, __providers__, __map__
  13. from g4f.providers.base_provider import ProviderModelMixin
  14. from g4f.providers.response import BaseConversation, FinishReason, SynthesizeData
  15. from g4f.client.service import convert_to_provider
  16. from g4f import debug
  logger = logging.getLogger(__name__)
  # Live provider conversations, keyed first by provider name and then by the
  # GUI-supplied conversation id. Api._create_response_stream stores them and
  # Api._prepare_conversation_kwargs reads them back for follow-up requests.
  conversations: dict[str, dict[str, BaseConversation]] = {}
  class Api:
      """Server-side helpers behind the GUI.

      Lists models and providers, reports version information, serves
      generated images and turns ``ChatCompletion`` output into a stream of
      ``{"type": <kind>, <kind>: <payload>}`` event dicts.
      """

      @staticmethod
      def get_models() -> list[str]:
          """Return ``g4f.models._all_models`` unchanged."""
          return models._all_models

      @staticmethod
      def get_provider_models(provider: str, api_key: str = None) -> list[dict]:
          """List the models offered by the provider registered as *provider*.

          Args:
              provider: Provider name, looked up in ``g4f.Provider.__map__``.
              api_key: Forwarded to ``get_models`` only when that method
                  declares an ``api_key`` parameter.

          Returns:
              One dict per model with the keys ``model``, ``default``,
              ``vision`` and ``image``. Returns an empty list when the name is
              unknown or the provider does not subclass ``ProviderModelMixin``.
          """
          if provider not in __map__:
              return []
          provider_class: ProviderType = __map__[provider]
          if not issubclass(provider_class, ProviderModelMixin):
              return []
          if api_key is not None and "api_key" in signature(provider_class.get_models).parameters:
              provider_models = provider_class.get_models(api_key=api_key)
          else:
              provider_models = provider_class.get_models()
          # Read after get_models(), which may refresh these attributes.
          default_vision_model = getattr(provider_class, "default_vision_model", None)
          vision_models = getattr(provider_class, "vision_models", [])
          image_models = getattr(provider_class, "image_models", None)
          return [
              {
                  "model": model,
                  "default": model == provider_class.default_model,
                  "vision": default_vision_model == model or model in vision_models,
                  "image": False if image_models is None else model in image_models,
              }
              for model in provider_models
          ]

      @staticmethod
      def get_image_models() -> list[dict]:
          """Describe providers that can generate images or accept image uploads.

          A provider with a non-empty ``image_models`` contributes one entry per
          image model. The entry is reported under its ``parent`` provider when
          it declares one, and each parent is listed only once. A provider
          without image models but with a ``default_vision_model`` contributes
          a single entry whose ``image_model`` is None.
          """
          image_models = []
          listed = set()  # provider names already emitted
          for provider in __providers__:
              if hasattr(provider, "image_models") and hasattr(provider, "get_models"):
                  # Called only for its side effect; presumably it refreshes
                  # ``image_models``. The return value is unused.
                  provider.get_models()
              # Truthiness (as in get_providers) rather than hasattr:
              # ``image_models`` may be None (see get_provider_models), and
              # iterating None raised TypeError.
              if getattr(provider, "image_models", None):
                  parent = __map__[provider.parent] if hasattr(provider, "parent") else provider
                  if parent.__name__ not in listed:
                      image_models.extend(
                          {
                              "provider": parent.__name__,
                              "url": parent.url,
                              "label": getattr(parent, "label", None),
                              "image_model": model,
                              "vision_model": getattr(parent, "default_vision_model", None),
                          }
                          for model in provider.image_models
                      )
                      listed.add(parent.__name__)
              elif getattr(provider, "default_vision_model", None) and provider.__name__ not in listed:
                  image_models.append({
                      "provider": provider.__name__,
                      "url": provider.url,
                      "label": getattr(provider, "label", None),
                      "image_model": None,
                      "vision_model": provider.default_vision_model,
                  })
                  listed.add(provider.__name__)
          return image_models

      @staticmethod
      def get_providers() -> dict[str, str]:
          """Map each working provider's class name to a display label.

          The label is the provider's ``label`` (falling back to its class name)
          followed by capability tags: " (Image Generation)", " (Image Upload)",
          " (WebDriver)" and " (Auth)".
          """
          return {
              provider.__name__: (provider.label if hasattr(provider, "label") else provider.__name__)
              + (" (Image Generation)" if getattr(provider, "image_models", None) else "")
              + (" (Image Upload)" if getattr(provider, "default_vision_model", None) else "")
              + (" (WebDriver)" if "webdriver" in provider.get_parameters() else "")
              + (" (Auth)" if provider.needs_auth else "")
              for provider in __providers__
              if provider.working
          }

      @staticmethod
      def get_version():
          """Return ``{"version": installed, "latest_version": latest}``.

          ``version`` is None when the installed version cannot be determined.
          NOTE(review): ``latest_version`` is read without a guard; confirm it
          cannot raise ``VersionNotFoundError`` as well.
          """
          try:
              current_version = version.utils.current_version
          except VersionNotFoundError:
              current_version = None
          return {
              "version": current_version,
              "latest_version": version.utils.latest_version,
          }

      def serve_images(self, name):
          """Call ``ensure_images_dir()``, then send file *name* from ``images_dir``."""
          ensure_images_dir()
          return send_from_directory(os.path.abspath(images_dir), name)

      def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
          """Build the keyword arguments for ``ChatCompletion.create`` from a GUI request.

          Args:
              json_data: Request payload. ``messages`` is required. ``model``,
                  ``provider``, ``api_key``, ``web_search``, ``auto_continue``
                  and ``conversation_id`` are optional.
              kwargs: Extra arguments. They are updated in place and merged
                  last, so their entries override the defaults below.

          Returns:
              The merged keyword arguments (streaming, ``ignore_stream`` and
              ``return_conversation`` enabled unless *kwargs* overrides them).
          """
          model = json_data.get('model') or models.default
          provider = json_data.get('provider')
          messages = json_data['messages']
          api_key = json_data.get("api_key")
          if api_key is not None:
              kwargs["api_key"] = api_key
          do_web_search = json_data.get('web_search')
          if do_web_search and provider:
              # Let providers that support web search natively do it themselves.
              provider_handler = convert_to_provider(provider)
              if hasattr(provider_handler, "get_parameters"):
                  if "web_search" in provider_handler.get_parameters():
                      kwargs['web_search'] = True
                      do_web_search = False
          if do_web_search and messages:
              from .internet import get_search_message
              # Rebuild the last message instead of editing it in place, so the
              # caller's request payload is left untouched.
              last_message = {**messages[-1], "content": get_search_message(messages[-1]["content"])}
              messages = [*messages[:-1], last_message]
          if json_data.get("auto_continue"):
              kwargs['auto_continue'] = True
          conversation_id = json_data.get("conversation_id")
          if conversation_id and provider:
              if provider in conversations and conversation_id in conversations[provider]:
                  kwargs["conversation"] = conversations[provider][conversation_id]
          return {
              "model": model,
              "provider": provider,
              "messages": messages,
              "stream": True,
              "ignore_stream": True,
              "return_conversation": True,
              **kwargs
          }

      def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
          """Run a chat completion and yield GUI events for its output.

          Args:
              kwargs: Arguments passed to ``ChatCompletion.create``.
              conversation_id: Key under which a returned conversation is cached.
              provider: Provider name; conversations are cached only when set.
              download_images: Pass image URLs through ``copy_images`` before
                  emitting them.

          Yields:
              Event dicts from ``_format_json``. A "provider" event comes first,
              followed by "content", "conversation", "message", "preview",
              "synthesize" and "log" events. Any exception ends the stream with
              an "error" event.
          """
          installed_handler = None
          if debug.logging:
              debug.logs = []
              # Wrap the underlying handler, never a wrapper left by an earlier
              # stream. Nesting wrappers duplicated every log line once per past
              # request and made the handler chain grow without bound.
              print_callback = getattr(debug.log_handler, "_original_log_handler", debug.log_handler)

              def log_handler(text: str):
                  debug.logs.append(text)
                  print_callback(text)

              log_handler._original_log_handler = print_callback
              debug.log_handler = log_handler
              installed_handler = log_handler
          try:
              result = ChatCompletion.create(**kwargs)
              if isinstance(result, ImageResponse):
                  yield self._format_json("provider", get_last_provider(True))
                  yield self._format_json("content", str(result))
              else:
                  first = True
                  for chunk in result:
                      if first:
                          first = False
                          yield self._format_json("provider", get_last_provider(True))
                      if isinstance(chunk, BaseConversation):
                          if provider:
                              conversations.setdefault(provider, {})[conversation_id] = chunk
                              yield self._format_json("conversation", conversation_id)
                      elif isinstance(chunk, Exception):
                          # Not inside an ``except`` block, so pass the exception
                          # explicitly; logger.exception() here logged no traceback.
                          logger.error(chunk, exc_info=chunk)
                          yield self._format_json("message", get_error_message(chunk))
                      elif isinstance(chunk, ImagePreview):
                          yield self._format_json("preview", chunk.to_string())
                      elif isinstance(chunk, ImageResponse):
                          images = chunk
                          if download_images:
                              images = asyncio.run(copy_images(chunk.get_list(), chunk.options.get("cookies")))
                              images = ImageResponse(images, chunk.alt)
                          yield self._format_json("content", str(images))
                      elif isinstance(chunk, SynthesizeData):
                          yield self._format_json("synthesize", chunk.to_json())
                      elif not isinstance(chunk, FinishReason):
                          yield self._format_json("content", str(chunk))
                      # Forward log lines captured while producing this chunk.
                      if debug.logs:
                          for log in debug.logs:
                              yield self._format_json("log", str(log))
                          debug.logs = []
          except Exception as e:
              logger.exception(e)
              yield self._format_json('error', get_error_message(e))
          finally:
              # Restore the original handler unless a newer stream replaced ours.
              if installed_handler is not None and debug.log_handler is installed_handler:
                  debug.log_handler = installed_handler._original_log_handler

      def _format_json(self, response_type: str, content):
          """Wrap *content* as an event: ``{"type": response_type, response_type: content}``."""
          return {
              'type': response_type,
              response_type: content
          }
  def get_error_message(exception: Exception) -> str:
      """Render *exception* for the GUI as ``"<ExceptionType>: <message>"``.

      When ``get_last_provider()`` reports a provider, its class name is
      prepended, giving ``"<Provider>: <ExceptionType>: <message>"``.
      """
      description = f"{type(exception).__name__}: {exception}"
      last_provider = get_last_provider()
      return description if last_provider is None else f"{last_provider.__name__}: {description}"