# DeepInfra.py — scraped page header and line-number residue removed
from __future__ import annotations
import requests
from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status
from ...providers.response import ImageResponse
from ..template import OpenaiTemplate
from ..helper import format_image_prompt
  8. class DeepInfra(OpenaiTemplate):
  9. url = "https://deepinfra.com"
  10. login_url = "https://deepinfra.com/dash/api_keys"
  11. api_base = "https://api.deepinfra.com/v1/openai"
  12. working = True
  13. needs_auth = True
  14. default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
  15. default_image_model = "stabilityai/sd3.5"
  16. @classmethod
  17. def get_models(cls, **kwargs):
  18. if not cls.models:
  19. url = 'https://api.deepinfra.com/models/featured'
  20. response = requests.get(url)
  21. models = response.json()
  22. cls.models = []
  23. cls.image_models = []
  24. for model in models:
  25. if model["type"] == "text-generation":
  26. cls.models.append(model['model_name'])
  27. elif model["reported_type"] == "text-to-image":
  28. cls.image_models.append(model['model_name'])
  29. cls.models.extend(cls.image_models)
  30. return cls.models
  31. @classmethod
  32. def get_image_models(cls, **kwargs):
  33. if not cls.image_models:
  34. cls.get_models()
  35. return cls.image_models
  36. @classmethod
  37. async def create_async_generator(
  38. cls,
  39. model: str,
  40. messages: Messages,
  41. stream: bool,
  42. prompt: str = None,
  43. temperature: float = 0.7,
  44. max_tokens: int = 1028,
  45. **kwargs
  46. ) -> AsyncResult:
  47. if model in cls.get_image_models():
  48. yield cls.create_async_image(
  49. format_image_prompt(messages, prompt),
  50. model,
  51. **kwargs
  52. )
  53. return
  54. headers = {
  55. 'Accept-Encoding': 'gzip, deflate, br',
  56. 'Accept-Language': 'en-US',
  57. 'Origin': 'https://deepinfra.com',
  58. 'Referer': 'https://deepinfra.com/',
  59. 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
  60. 'X-Deepinfra-Source': 'web-embed',
  61. }
  62. async for chunk in super().create_async_generator(
  63. model, messages,
  64. stream=stream,
  65. temperature=temperature,
  66. max_tokens=max_tokens,
  67. headers=headers,
  68. **kwargs
  69. ):
  70. yield chunk
  71. @classmethod
  72. async def create_async_image(
  73. cls,
  74. prompt: str,
  75. model: str,
  76. api_key: str = None,
  77. api_base: str = "https://api.deepinfra.com/v1/inference",
  78. proxy: str = None,
  79. timeout: int = 180,
  80. extra_data: dict = {},
  81. **kwargs
  82. ) -> ImageResponse:
  83. headers = {
  84. 'Accept-Encoding': 'gzip, deflate, br',
  85. 'Accept-Language': 'en-US',
  86. 'Connection': 'keep-alive',
  87. 'Origin': 'https://deepinfra.com',
  88. 'Referer': 'https://deepinfra.com/',
  89. 'Sec-Fetch-Dest': 'empty',
  90. 'Sec-Fetch-Mode': 'cors',
  91. 'Sec-Fetch-Site': 'same-site',
  92. 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
  93. 'X-Deepinfra-Source': 'web-embed',
  94. 'sec-ch-ua': '"Google Chrome";v="119", "Chromium";v="119", "Not?A_Brand";v="24"',
  95. 'sec-ch-ua-mobile': '?0',
  96. 'sec-ch-ua-platform': '"macOS"',
  97. }
  98. if api_key is not None:
  99. headers["Authorization"] = f"Bearer {api_key}"
  100. async with StreamSession(
  101. proxies={"all": proxy},
  102. headers=headers,
  103. timeout=timeout
  104. ) as session:
  105. model = cls.get_model(model)
  106. data = {"prompt": prompt, **extra_data}
  107. data = {"input": data} if model == cls.default_model else data
  108. async with session.post(f"{api_base.rstrip('/')}/{model}", json=data) as response:
  109. await raise_for_status(response)
  110. data = await response.json()
  111. images = data.get("output", data.get("images", data.get("image_url")))
  112. if not images:
  113. raise RuntimeError(f"Response: {data}")
  114. images = images[0] if len(images) == 1 else images
  115. return ImageResponse(images, prompt)