# DeepInfraChat.py
  1. from __future__ import annotations
  2. from aiohttp import ClientSession
  3. import json
  4. from ..typing import AsyncResult, Messages, ImageType
  5. from ..image import to_data_uri
  6. from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  7. class DeepInfraChat(AsyncGeneratorProvider, ProviderModelMixin):
  8. url = "https://deepinfra.com/chat"
  9. api_endpoint = "https://api.deepinfra.com/v1/openai/chat/completions"
  10. working = True
  11. supports_stream = True
  12. supports_system_message = True
  13. supports_message_history = True
  14. default_model = 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo'
  15. models = [
  16. 'meta-llama/Meta-Llama-3.1-8B-Instruct',
  17. default_model,
  18. 'microsoft/WizardLM-2-8x22B',
  19. 'Qwen/Qwen2.5-72B-Instruct',
  20. ]
  21. model_aliases = {
  22. "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
  23. "llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
  24. "wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B",
  25. "qwen-2-72b": "Qwen/Qwen2.5-72B-Instruct",
  26. }
  27. @classmethod
  28. def get_model(cls, model: str) -> str:
  29. if model in cls.models:
  30. return model
  31. elif model in cls.model_aliases:
  32. return cls.model_aliases[model]
  33. else:
  34. return cls.default_model
  35. @classmethod
  36. async def create_async_generator(
  37. cls,
  38. model: str,
  39. messages: Messages,
  40. proxy: str = None,
  41. image: ImageType = None,
  42. image_name: str = None,
  43. **kwargs
  44. ) -> AsyncResult:
  45. model = cls.get_model(model)
  46. headers = {
  47. 'Accept-Language': 'en-US,en;q=0.9',
  48. 'Cache-Control': 'no-cache',
  49. 'Connection': 'keep-alive',
  50. 'Content-Type': 'application/json',
  51. 'Origin': 'https://deepinfra.com',
  52. 'Pragma': 'no-cache',
  53. 'Referer': 'https://deepinfra.com/',
  54. 'Sec-Fetch-Dest': 'empty',
  55. 'Sec-Fetch-Mode': 'cors',
  56. 'Sec-Fetch-Site': 'same-site',
  57. 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36',
  58. 'X-Deepinfra-Source': 'web-embed',
  59. 'accept': 'text/event-stream',
  60. 'sec-ch-ua': '"Not;A=Brand";v="24", "Chromium";v="128"',
  61. 'sec-ch-ua-mobile': '?0',
  62. 'sec-ch-ua-platform': '"Linux"',
  63. }
  64. async with ClientSession(headers=headers) as session:
  65. data = {
  66. 'model': model,
  67. 'messages': messages,
  68. 'stream': True
  69. }
  70. async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
  71. response.raise_for_status()
  72. async for line in response.content:
  73. if line:
  74. decoded_line = line.decode('utf-8').strip()
  75. if decoded_line.startswith('data:'):
  76. json_part = decoded_line[5:].strip()
  77. if json_part == '[DONE]':
  78. break
  79. try:
  80. data = json.loads(json_part)
  81. choices = data.get('choices', [])
  82. if choices:
  83. delta = choices[0].get('delta', {})
  84. content = delta.get('content', '')
  85. if content:
  86. yield content
  87. except json.JSONDecodeError:
  88. print(f"JSON decode error: {json_part}")