DeepInfraImage.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. from __future__ import annotations
  2. import requests
  3. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  4. from ...typing import AsyncResult, Messages
  5. from ...requests import StreamSession, raise_for_status
  6. from ...image import ImageResponse
  7. class DeepInfraImage(AsyncGeneratorProvider, ProviderModelMixin):
  8. url = "https://deepinfra.com"
  9. parent = "DeepInfra"
  10. working = True
  11. needs_auth = True
  12. default_model = ''
  13. image_models = [default_model]
  14. @classmethod
  15. def get_models(cls):
  16. if not cls.models:
  17. url = 'https://api.deepinfra.com/models/featured'
  18. models = requests.get(url).json()
  19. cls.models = [model['model_name'] for model in models if model["reported_type"] == "text-to-image"]
  20. cls.image_models = cls.models
  21. return cls.models
  22. @classmethod
  23. async def create_async_generator(
  24. cls,
  25. model: str,
  26. messages: Messages,
  27. prompt: str = None,
  28. **kwargs
  29. ) -> AsyncResult:
  30. yield await cls.create_async(messages[-1]["content"] if prompt is None else prompt, model, **kwargs)
  31. @classmethod
  32. async def create_async(
  33. cls,
  34. prompt: str,
  35. model: str,
  36. api_key: str = None,
  37. api_base: str = "https://api.deepinfra.com/v1/inference",
  38. proxy: str = None,
  39. timeout: int = 180,
  40. extra_data: dict = {},
  41. **kwargs
  42. ) -> ImageResponse:
  43. headers = {
  44. 'Accept-Encoding': 'gzip, deflate, br',
  45. 'Accept-Language': 'en-US',
  46. 'Connection': 'keep-alive',
  47. 'Origin': 'https://deepinfra.com',
  48. 'Referer': 'https://deepinfra.com/',
  49. 'Sec-Fetch-Dest': 'empty',
  50. 'Sec-Fetch-Mode': 'cors',
  51. 'Sec-Fetch-Site': 'same-site',
  52. 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
  53. 'X-Deepinfra-Source': 'web-embed',
  54. 'sec-ch-ua': '"Google Chrome";v="119", "Chromium";v="119", "Not?A_Brand";v="24"',
  55. 'sec-ch-ua-mobile': '?0',
  56. 'sec-ch-ua-platform': '"macOS"',
  57. }
  58. if api_key is not None:
  59. headers["Authorization"] = f"Bearer {api_key}"
  60. async with StreamSession(
  61. proxies={"all": proxy},
  62. headers=headers,
  63. timeout=timeout
  64. ) as session:
  65. model = cls.get_model(model)
  66. data = {"prompt": prompt, **extra_data}
  67. data = {"input": data} if model == cls.default_model else data
  68. async with session.post(f"{api_base.rstrip('/')}/{model}", json=data) as response:
  69. await raise_for_status(response)
  70. data = await response.json()
  71. images = data.get("output", data.get("images"))
  72. if not images:
  73. raise RuntimeError(f"Response: {data}")
  74. images = images[0] if len(images) == 1 else images
  75. return ImageResponse(images, prompt)