OpenaiAPI.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from __future__ import annotations
  2. import json
  3. from ..helper import filter_none
  4. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, FinishReason
  5. from ...typing import Union, Optional, AsyncResult, Messages, ImageType
  6. from ...requests import StreamSession, raise_for_status
  7. from ...errors import MissingAuthError, ResponseError
  8. from ...image import to_data_uri
  class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin):
      """Provider that sends chat requests to an OpenAI-style HTTP API.

      Requests are POSTed to ``{api_base}/chat/completions``; the reply is read
      either as a single JSON document or as a ``data: ...`` line stream,
      depending on ``stream``.
      """

      label = "OpenAI API"
      url = "https://platform.openai.com"
      working = True
      needs_auth = True
      supports_message_history = True
      supports_system_message = True
      default_model = ""

      @classmethod
      async def create_async_generator(
          cls,
          model: str,
          messages: Messages,
          proxy: Optional[str] = None,
          timeout: int = 120,
          image: ImageType = None,
          api_key: Optional[str] = None,
          api_base: str = "https://api.openai.com/v1",
          temperature: Optional[float] = None,
          max_tokens: Optional[int] = None,
          top_p: Optional[float] = None,
          stop: Union[str, list[str]] = None,
          stream: bool = False,
          headers: Optional[dict] = None,
          impersonate: Optional[str] = None,
          extra_data: Optional[dict] = None,
          **kwargs
      ) -> AsyncResult:
          """Send a chat-completion request and yield the reply.

          Args:
              model: Model name; resolved through ``cls.get_model``. When empty
                  and an image is given, ``cls.default_vision_model`` is used if
                  the class defines it.
              messages: Conversation to send. The caller's list is never
                  modified, even when an image is attached.
              proxy: Proxy URL handed to ``StreamSession`` for all schemes.
              timeout: Timeout value handed to ``StreamSession``.
              image: Optional image; it is converted with ``to_data_uri`` and
                  attached to the last message together with its text.
              api_key: Bearer token. Required when ``cls.needs_auth`` is true.
              api_base: Base URL of the API; a trailing ``/`` is ignored.
              temperature, max_tokens, top_p, stop: Forwarded in the request
                  body (passed through ``filter_none`` first, which presumably
                  drops the ones left as ``None`` -- confirm in ``helper``).
              stream: Read the reply as a ``data:`` line stream instead of a
                  single JSON document.
              headers: Extra HTTP headers; they override the generated ones.
              impersonate: Forwarded to ``StreamSession`` unchanged.
              extra_data: Additional fields merged into the request body.
              **kwargs: Accepted for interface compatibility; unused here.

          Yields:
              Text of the reply (one string, or one string per streamed delta),
              followed by a ``FinishReason`` when the API reports one.

          Raises:
              MissingAuthError: ``cls.needs_auth`` is true and no key was given.
              ResponseError: The API payload describes an error.
          """
          if cls.needs_auth and api_key is None:
              raise MissingAuthError('Add a "api_key"')

          if image is not None:
              if not model and hasattr(cls, "default_vision_model"):
                  model = cls.default_vision_model
              # Build a new list and a new last message instead of rewriting the
              # caller's objects: the same `messages` may be reused for another
              # call, and rewriting in place would nest the content a second time.
              last_message = messages[-1]
              messages = [
                  *messages[:-1],
                  {
                      **last_message,
                      "content": [
                          {
                              "type": "image_url",
                              "image_url": {"url": to_data_uri(image)}
                          },
                          {
                              "type": "text",
                              "text": last_message["content"]
                          }
                      ]
                  }
              ]

          async with StreamSession(
              proxies={"all": proxy},
              headers=cls.get_headers(stream, api_key, headers),
              timeout=timeout,
              impersonate=impersonate,
          ) as session:
              payload = filter_none(
                  messages=messages,
                  model=cls.get_model(model),
                  temperature=temperature,
                  max_tokens=max_tokens,
                  top_p=top_p,
                  stop=stop,
                  stream=stream,
                  **({} if extra_data is None else extra_data)
              )
              endpoint = f"{api_base.rstrip('/')}/chat/completions"
              async with session.post(endpoint, json=payload) as response:
                  await raise_for_status(response)
                  if not stream:
                      result = await response.json()
                      cls.raise_error(result)
                      choice = result["choices"][0]
                      # "content" may be present but null (e.g. a reply that only
                      # carries tool calls), so test the value, not the key.
                      content = (choice.get("message") or {}).get("content")
                      if content is not None:
                          yield content.strip()
                      finish = cls.read_finish_reason(choice)
                      if finish is not None:
                          yield finish
                  else:
                      # Leading whitespace is trimmed only until the first
                      # non-empty piece of text has been emitted.
                      first = True
                      async for line in response.iter_lines():
                          if not line.startswith(b"data: "):
                              continue
                          chunk = line[6:]
                          if chunk == b"[DONE]":
                              break
                          event = json.loads(chunk)
                          cls.raise_error(event)
                          choices = event.get("choices")
                          if not choices:
                              # Some chunks carry no choices at all (for example
                              # a trailing usage report); nothing to emit.
                              continue
                          choice = choices[0]
                          delta = (choice.get("delta") or {}).get("content")
                          if delta:
                              if first:
                                  delta = delta.lstrip()
                              if delta:
                                  first = False
                                  yield delta
                          finish = cls.read_finish_reason(choice)
                          if finish is not None:
                              yield finish

      @staticmethod
      def read_finish_reason(choice: dict) -> Optional[FinishReason]:
          """Wrap ``choice["finish_reason"]`` in a ``FinishReason``.

          Returns ``None`` when the key is missing or its value is ``None``.
          """
          reason = choice.get("finish_reason")
          if reason is not None:
              return FinishReason(reason)
          return None

      @staticmethod
      def raise_error(data: dict):
          """Raise ``ResponseError`` if the decoded payload describes an error.

          Two layouts are recognised: a top-level ``"error_message"`` value, or
          an ``"error"`` entry. The latter is normally an object with ``code``
          and ``message``; missing keys are rendered as ``None`` and a
          non-object value is reported as-is, so a malformed error still
          surfaces as ``ResponseError`` rather than ``KeyError``/``TypeError``.
          A null ``"error"`` is treated as "no error".
          """
          if "error_message" in data:
              raise ResponseError(data["error_message"])
          error = data.get("error")
          if error is None:
              return
          if isinstance(error, dict):
              raise ResponseError(f'Error {error.get("code")}: {error.get("message")}')
          raise ResponseError(f'Error: {error}')

      @classmethod
      def get_headers(cls, stream: bool, api_key: str = None, headers: dict = None) -> dict:
          """Build the HTTP headers for a request.

          Args:
              stream: Selects the ``Accept`` header: ``text/event-stream`` when
                  true, ``application/json`` otherwise.
              api_key: When given, sent as ``Authorization: Bearer <api_key>``.
              headers: Extra headers; merged last, so they override the
                  generated entries on key collision.

          Returns:
              A new dict of header names to values.
          """
          return {
              "Accept": "text/event-stream" if stream else "application/json",
              "Content-Type": "application/json",
              **(
                  {"Authorization": f"Bearer {api_key}"}
                  if api_key is not None else {}
              ),
              **({} if headers is None else headers)
          }