GeminiPro.py

from __future__ import annotations

import base64
import json
from aiohttp import ClientSession, BaseConnector

from ..typing import AsyncResult, Messages, ImageType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..image import to_bytes, is_accepted_format
from ..errors import MissingAuthError
from .helper import get_connector

class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://ai.google.dev"
    working = True
    supports_message_history = True
    needs_auth = True
    default_model = "gemini-pro"
    models = ["gemini-pro", "gemini-pro-vision"]

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        proxy: str = None,
        api_key: str = None,
        api_base: str = "https://generativelanguage.googleapis.com/v1beta",
        use_auth_header: bool = False,
        image: ImageType = None,
        connector: BaseConnector = None,
        **kwargs
    ) -> AsyncResult:
        # Fall back to the vision model when an image is passed without an explicit model.
        model = "gemini-pro-vision" if not model and image is not None else model
        model = cls.get_model(model)

        if not api_key:
            raise MissingAuthError('Missing "api_key"')

        # Authenticate either with a Bearer header or with the "key" query parameter.
        headers = params = None
        if use_auth_header:
            headers = {"Authorization": f"Bearer {api_key}"}
        else:
            params = {"key": api_key}

        method = "streamGenerateContent" if stream else "generateContent"
        url = f"{api_base.rstrip('/')}/models/{model}:{method}"
        async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
            # Map OpenAI-style roles onto the Gemini schema ("assistant" becomes "model").
            contents = [
                {
                    "role": "model" if message["role"] == "assistant" else "user",
                    "parts": [{"text": message["content"]}]
                }
                for message in messages
            ]
            if image is not None:
                # Attach the image to the last message as base64-encoded inline data.
                image = to_bytes(image)
                contents[-1]["parts"].append({
                    "inline_data": {
                        "mime_type": is_accepted_format(image),
                        "data": base64.b64encode(image).decode()
                    }
                })
            data = {
                "contents": contents,
                "generationConfig": {
                    "stopSequences": kwargs.get("stop"),
                    "temperature": kwargs.get("temperature"),
                    "maxOutputTokens": kwargs.get("max_tokens"),
                    "topP": kwargs.get("top_p"),
                    "topK": kwargs.get("top_k"),
                }
            }
            async with session.post(url, params=params, json=data) as response:
                if not response.ok:
                    data = await response.json()
                    data = data[0] if isinstance(data, list) else data
                    raise RuntimeError(data["error"]["message"])
                if stream:
                    # The streaming endpoint returns a JSON array of response objects;
                    # reassemble each object from the raw chunks and yield its text part.
                    lines = []
                    async for chunk in response.content:
                        if chunk == b"[{\n":
                            lines = [b"{\n"]
                        elif chunk == b",\r\n" or chunk == b"]":
                            try:
                                data = b"".join(lines)
                                data = json.loads(data)
                                yield data["candidates"][0]["content"]["parts"][0]["text"]
                            except Exception:
                                data = data.decode() if isinstance(data, bytes) else data
                                raise RuntimeError(f"Read chunk failed: {data}")
                            lines = []
                        else:
                            lines.append(chunk)
                else:
                    data = await response.json()
                    yield data["candidates"][0]["content"]["parts"][0]["text"]
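

# Usage sketch (illustrative, not part of the provider file): create_async_generator
# is an async generator function, so it can be iterated directly with "async for".
# The model name, messages shape, and api_key parameter follow the signature above;
# the placeholder key and this main() wrapper are assumptions for demonstration only.
#
# import asyncio
#
# async def main():
#     async for chunk in GeminiPro.create_async_generator(
#         model="gemini-pro",
#         messages=[{"role": "user", "content": "Hello"}],
#         api_key="YOUR_API_KEY",  # hypothetical placeholder
#         stream=True,
#     ):
#         print(chunk, end="")
#
# asyncio.run(main())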