GeminiPro.py

from __future__ import annotations

import base64
import json
from typing import Optional

import requests
from aiohttp import ClientSession, BaseConnector

from ...typing import AsyncResult, Messages, ImagesType
from ...image import to_bytes, is_accepted_format
from ...errors import MissingAuthError
from ...requests.raise_for_status import raise_for_status
from ...providers.response import Usage, FinishReason
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_connector
from ... import debug
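

# Async provider for Google's Gemini REST API: authenticates with an API key
# (query parameter or Bearer header) and supports streaming and non-streaming
# completions, inline image input, and tool (function) declarations.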
class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Google Gemini API"
    url = "https://ai.google.dev"
    login_url = "https://aistudio.google.com/u/0/apikey"
    api_base = "https://generativelanguage.googleapis.com/v1beta"

    working = True
    supports_message_history = True
    supports_system_message = True
    needs_auth = True

    default_model = "gemini-1.5-pro"
    default_vision_model = default_model
    fallback_models = [default_model, "gemini-2.0-flash-exp", "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
    model_aliases = {
        "gemini-1.5-flash": "gemini-1.5-flash",
        "gemini-1.5-flash-8b": "gemini-1.5-flash-8b",
        "gemini-1.5-pro": "gemini-pro",
        "gemini-2.0-flash": "gemini-2.0-flash-exp",
    }
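
    # Model discovery: ask the API which models support generateContent and
    # cache the result; fall back to the static list above on any failure.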
    @classmethod
    def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
        if not cls.models:
            try:
                url = f"{cls.api_base if not api_base else api_base}/models"
                response = requests.get(url, params={"key": api_key})
                raise_for_status(response)
                data = response.json()
                cls.models = [
                    model.get("name").split("/").pop()
                    for model in data.get("models")
                    if "generateContent" in model.get("supportedGenerationMethods")
                ]
                cls.models.sort()
            except Exception as e:
                debug.log(e)
                return cls.fallback_models
        return cls.models
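
    # Core entry point: converts g4f messages into the Gemini "contents"
    # format and yields text chunks plus FinishReason and Usage objects.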
    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        proxy: str = None,
        api_key: str = None,
        api_base: str = api_base,
        use_auth_header: bool = False,
        images: ImagesType = None,
        tools: Optional[list] = None,
        connector: BaseConnector = None,
        **kwargs
    ) -> AsyncResult:
        if not api_key:
            raise MissingAuthError('Add an "api_key"')
        model = cls.get_model(model, api_key=api_key, api_base=api_base)

        headers = params = None
        if use_auth_header:
            headers = {"Authorization": f"Bearer {api_key}"}
        else:
            params = {"key": api_key}

        method = "streamGenerateContent" if stream else "generateContent"
        url = f"{api_base.rstrip('/')}/models/{model}:{method}"
        async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
            contents = [
                {
                    "role": "model" if message["role"] == "assistant" else "user",
                    "parts": [{"text": message["content"]}]
                }
                for message in messages
                if message["role"] != "system"
            ]
            if images is not None:
                for image, _ in images:
                    image = to_bytes(image)
                    contents[-1]["parts"].append({
                        "inline_data": {
                            "mime_type": is_accepted_format(image),
                            "data": base64.b64encode(image).decode()
                        }
                    })
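            # Map g4f kwargs onto the API's camelCase generationConfig fields;
            # options left unset are serialized as JSON null.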
            data = {
                "contents": contents,
                "generationConfig": {
                    "stopSequences": kwargs.get("stop"),
                    "temperature": kwargs.get("temperature"),
                    "maxOutputTokens": kwargs.get("max_tokens"),
                    "topP": kwargs.get("top_p"),
                    "topK": kwargs.get("top_k"),
                },
                "tools": [{
                    "functionDeclarations": tools
                }] if tools else None
            }
            system_prompt = "\n".join(
                message["content"]
                for message in messages
                if message["role"] == "system"
            )
            if system_prompt:
                data["system_instruction"] = {"parts": {"text": system_prompt}}
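
            # streamGenerateContent returns a single JSON array delivered
            # incrementally; the streaming branch below reassembles complete
            # objects by watching for the array and object delimiter chunks.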
            async with session.post(url, params=params, json=data) as response:
                if not response.ok:
                    data = await response.json()
                    data = data[0] if isinstance(data, list) else data
                    raise RuntimeError(f"Response {response.status}: {data['error']['message']}")
                if stream:
                    lines = []
                    async for chunk in response.content:
                        if chunk == b"[{\n":
                            lines = [b"{\n"]
                        elif chunk == b",\r\n" or chunk == b"]":
                            try:
                                data = b"".join(lines)
                                data = json.loads(data)
                                yield data["candidates"][0]["content"]["parts"][0]["text"]
                                if "finishReason" in data["candidates"][0]:
                                    yield FinishReason(data["candidates"][0]["finishReason"].lower())
                                usage = data.get("usageMetadata")
                                if usage:
                                    yield Usage(
                                        prompt_tokens=usage.get("promptTokenCount"),
                                        completion_tokens=usage.get("candidatesTokenCount"),
                                        total_tokens=usage.get("totalTokenCount")
                                    )
                            except Exception:
                                data = data.decode(errors="ignore") if isinstance(data, bytes) else data
                                raise RuntimeError(f"Read chunk failed: {data}")
                            lines = []
                        else:
                            lines.append(chunk)
                else:
                    data = await response.json()
                    candidate = data["candidates"][0]
                    if candidate.get("finishReason") == "STOP":
                        yield candidate["content"]["parts"][0]["text"]
                    else:
                        # safetyRatings is a list of dicts, so stringify it
                        # before joining it to the finish reason.
                        yield f"{candidate.get('finishReason')} {candidate.get('safetyRatings')}"
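

# --- Usage sketch -------------------------------------------------------------
# A minimal, illustrative way to drive this provider from a separate script
# (this module's relative imports prevent running it directly). Assumes the
# usual g4f package layout; "YOUR_API_KEY" is a placeholder, not a real key.
#
#     import asyncio
#     from g4f.Provider.needs_auth import GeminiPro
#
#     async def demo():
#         async for chunk in GeminiPro.create_async_generator(
#             model="gemini-1.5-pro",
#             messages=[{"role": "user", "content": "Say hello."}],
#             stream=True,
#             api_key="YOUR_API_KEY",  # placeholder (assumption)
#         ):
#             if isinstance(chunk, str):  # skip FinishReason/Usage objects
#                 print(chunk, end="", flush=True)
#
#     asyncio.run(demo())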