DeepInfraChat.py

from __future__ import annotations

import json

from aiohttp import ClientSession

from ..typing import AsyncResult, Messages, ImageType
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

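# Async provider for DeepInfra's hosted chat models: posts to the
# OpenAI-compatible streaming endpoint and yields content deltas as
# they arrive over the server-sent-event response.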
class DeepInfraChat(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://deepinfra.com/chat"
    api_endpoint = "https://api.deepinfra.com/v1/openai/chat/completions"
    working = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True

    default_model = 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo'
    models = [
        'meta-llama/Meta-Llama-3.1-8B-Instruct',
        default_model,
        'microsoft/WizardLM-2-8x22B',
        'Qwen/Qwen2.5-72B-Instruct',
    ]
    # Short aliases for the fully qualified model IDs above. Note that
    # "qwen-2-72b" resolves to the Qwen2.5 checkpoint listed in `models`.
    model_aliases = {
        "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        "wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B",
        "qwen-2-72b": "Qwen/Qwen2.5-72B-Instruct",
    }
    @classmethod
    def get_model(cls, model: str) -> str:
        """Resolve a model name or alias, falling back to the default model."""
        if model in cls.models:
            return model
        if model in cls.model_aliases:
            return cls.model_aliases[model]
        return cls.default_model
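
    # Example resolution (hypothetical inputs):
    #   get_model("llama-3.1-8b")  -> "meta-llama/Meta-Llama-3.1-8B-Instruct"
    #   get_model("no-such-model") -> cls.default_model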

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        # image/image_name are accepted for interface compatibility
        # but are not used by this provider.
        image: ImageType = None,
        image_name: str = None,
        **kwargs
    ) -> AsyncResult:
        model = cls.get_model(model)

        # Browser-style headers mirroring the DeepInfra web chat embed;
        # 'X-Deepinfra-Source' identifies the request as coming from the widget.
        headers = {
            'Accept-Language': 'en-US,en;q=0.9',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'Content-Type': 'application/json',
            'Origin': 'https://deepinfra.com',
            'Pragma': 'no-cache',
            'Referer': 'https://deepinfra.com/',
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-site',
            'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36',
            'X-Deepinfra-Source': 'web-embed',
            'accept': 'text/event-stream',
            'sec-ch-ua': '"Not;A=Brand";v="24", "Chromium";v="128"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Linux"',
        }
        async with ClientSession(headers=headers) as session:
            payload = {
                'model': model,
                'messages': messages,
                'stream': True
            }
            async with session.post(cls.api_endpoint, json=payload, proxy=proxy) as response:
                response.raise_for_status()
                # The endpoint streams OpenAI-style server-sent events: each
                # line is "data: {json}" and the stream ends with "data: [DONE]".
                async for line in response.content:
                    if not line:
                        continue
                    decoded_line = line.decode('utf-8').strip()
                    if not decoded_line.startswith('data:'):
                        continue
                    json_part = decoded_line[5:].strip()
                    if json_part == '[DONE]':
                        break
                    try:
                        chunk = json.loads(json_part)
                        choices = chunk.get('choices', [])
                        if choices:
                            # Each streamed chunk carries a partial message in
                            # choices[0].delta; yield its text as it arrives.
                            delta = choices[0].get('delta', {})
                            content = delta.get('content', '')
                            if content:
                                yield content
                    except json.JSONDecodeError:
                        print(f"JSON decode error: {json_part}")