# PollinationsAI.py (8.9 KB)
  1. from __future__ import annotations
  2. import json
  3. import random
  4. import requests
  5. from urllib.parse import quote_plus
  6. from typing import Optional
  7. from aiohttp import ClientSession
  8. from .helper import filter_none, format_image_prompt
  9. from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  10. from ..typing import AsyncResult, Messages, ImagesType
  11. from ..image import to_data_uri
  12. from ..errors import ModelNotFoundError
  13. from ..requests.raise_for_status import raise_for_status
  14. from ..requests.aiohttp import get_connector
  15. from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Reasoning
# Default headers applied to every outgoing HTTP request; the User-Agent
# mimics a desktop Chrome browser on Linux.
DEFAULT_HEADERS = {
    'Accept': '*/*',
    'Accept-Language': 'en-US,en;q=0.9',
    'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
}
  21. class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
  22. label = "Pollinations AI"
  23. url = "https://pollinations.ai"
  24. working = True
  25. supports_stream = False
  26. supports_system_message = True
  27. supports_message_history = True
  28. # API endpoints
  29. text_api_endpoint = "https://text.pollinations.ai/openai"
  30. image_api_endpoint = "https://image.pollinations.ai/"
  31. # Models configuration
  32. default_model = "openai"
  33. default_image_model = "flux"
  34. default_vision_model = "gpt-4o"
  35. extra_image_models = ["flux-pro", "flux-dev", "flux-schnell", "midjourney", "dall-e-3"]
  36. vision_models = [default_vision_model, "gpt-4o-mini"]
  37. extra_text_models = ["claude", "claude-email", "deepseek-reasoner", "deepseek-r1"] + vision_models
  38. model_aliases = {
  39. ### Text Models ###
  40. "gpt-4o-mini": "openai",
  41. "gpt-4": "openai-large",
  42. "gpt-4o": "openai-large",
  43. "qwen-2.5-coder-32b": "qwen-coder",
  44. "llama-3.3-70b": "llama",
  45. "mistral-nemo": "mistral",
  46. "gpt-4o-mini": "rtist",
  47. "gpt-4o": "searchgpt",
  48. "gpt-4o-mini": "p1",
  49. "deepseek-chat": "claude-hybridspace",
  50. "llama-3.1-8b": "llamalight",
  51. "gpt-4o-vision": "gpt-4o",
  52. "gpt-4o-mini-vision": "gpt-4o-mini",
  53. "gpt-4o-mini": "claude",
  54. "deepseek-chat": "claude-email",
  55. "deepseek-r1": "deepseek-reasoner",
  56. "gemini-2.0-flash": "gemini",
  57. "gemini-2.0-flash-thinking": "gemini-thinking",
  58. ### Image Models ###
  59. "sdxl-turbo": "turbo",
  60. }
  61. text_models = []
  62. image_models = []
  63. @classmethod
  64. def get_models(cls, **kwargs):
  65. if not cls.text_models or not cls.image_models:
  66. try:
  67. image_response = requests.get("https://image.pollinations.ai/models")
  68. image_response.raise_for_status()
  69. new_image_models = image_response.json()
  70. cls.image_models = list(dict.fromkeys([*cls.extra_image_models, *new_image_models]))
  71. text_response = requests.get("https://text.pollinations.ai/models")
  72. text_response.raise_for_status()
  73. original_text_models = [model.get("name") for model in text_response.json()]
  74. combined_text = cls.extra_text_models + [
  75. model for model in original_text_models
  76. if model not in cls.extra_text_models
  77. ]
  78. cls.text_models = list(dict.fromkeys(combined_text))
  79. except Exception as e:
  80. raise RuntimeError(f"Failed to fetch models: {e}") from e
  81. return cls.text_models + cls.image_models
  82. @classmethod
  83. async def create_async_generator(
  84. cls,
  85. model: str,
  86. messages: Messages,
  87. proxy: str = None,
  88. prompt: str = None,
  89. width: int = 1024,
  90. height: int = 1024,
  91. seed: Optional[int] = None,
  92. nologo: bool = True,
  93. private: bool = False,
  94. enhance: bool = False,
  95. safe: bool = False,
  96. images: ImagesType = None,
  97. temperature: float = None,
  98. presence_penalty: float = None,
  99. top_p: float = 1,
  100. frequency_penalty: float = None,
  101. response_format: Optional[dict] = None,
  102. cache: bool = False,
  103. **kwargs
  104. ) -> AsyncResult:
  105. if images is not None and not model:
  106. model = cls.default_vision_model
  107. try:
  108. model = cls.get_model(model)
  109. except ModelNotFoundError:
  110. if model not in cls.image_models:
  111. raise
  112. if not cache and seed is None:
  113. seed = random.randint(0, 10000)
  114. if model in cls.image_models:
  115. async for chunk in cls._generate_image(
  116. model=model,
  117. prompt=format_image_prompt(messages, prompt),
  118. proxy=proxy,
  119. width=width,
  120. height=height,
  121. seed=seed,
  122. nologo=nologo,
  123. private=private,
  124. enhance=enhance,
  125. safe=safe
  126. ):
  127. yield chunk
  128. else:
  129. async for result in cls._generate_text(
  130. model=model,
  131. messages=messages,
  132. images=images,
  133. proxy=proxy,
  134. temperature=temperature,
  135. presence_penalty=presence_penalty,
  136. top_p=top_p,
  137. frequency_penalty=frequency_penalty,
  138. response_format=response_format,
  139. seed=seed,
  140. cache=cache,
  141. ):
  142. yield result
  143. @classmethod
  144. async def _generate_image(
  145. cls,
  146. model: str,
  147. prompt: str,
  148. proxy: str,
  149. width: int,
  150. height: int,
  151. seed: Optional[int],
  152. nologo: bool,
  153. private: bool,
  154. enhance: bool,
  155. safe: bool
  156. ) -> AsyncResult:
  157. params = {
  158. "seed": str(seed) if seed is not None else None,
  159. "width": str(width),
  160. "height": str(height),
  161. "model": model,
  162. "nologo": str(nologo).lower(),
  163. "private": str(private).lower(),
  164. "enhance": str(enhance).lower(),
  165. "safe": str(safe).lower()
  166. }
  167. params = {k: v for k, v in params.items() if v is not None}
  168. query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items())
  169. url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}"
  170. yield ImagePreview(url, prompt)
  171. async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
  172. async with session.get(url, allow_redirects=True) as response:
  173. await raise_for_status(response)
  174. image_url = str(response.url)
  175. yield ImageResponse(image_url, prompt)
  176. @classmethod
  177. async def _generate_text(
  178. cls,
  179. model: str,
  180. messages: Messages,
  181. images: Optional[ImagesType],
  182. proxy: str,
  183. temperature: float,
  184. presence_penalty: float,
  185. top_p: float,
  186. frequency_penalty: float,
  187. response_format: Optional[dict],
  188. seed: Optional[int],
  189. cache: bool
  190. ) -> AsyncResult:
  191. json_mode = False
  192. if response_format and response_format.get("type") == "json_object":
  193. json_mode = True
  194. if images and messages:
  195. last_message = messages[-1].copy()
  196. image_content = [
  197. {
  198. "type": "image_url",
  199. "image_url": {"url": to_data_uri(image)}
  200. }
  201. for image, _ in images
  202. ]
  203. last_message["content"] = image_content + [{"type": "text", "text": last_message["content"]}]
  204. messages[-1] = last_message
  205. async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
  206. data = filter_none(**{
  207. "messages": messages,
  208. "model": model,
  209. "temperature": temperature,
  210. "presence_penalty": presence_penalty,
  211. "top_p": top_p,
  212. "frequency_penalty": frequency_penalty,
  213. "jsonMode": json_mode,
  214. "stream": False,
  215. "seed": seed,
  216. "cache": cache
  217. })
  218. async with session.post(cls.text_api_endpoint, json=data) as response:
  219. await raise_for_status(response)
  220. result = await response.json()
  221. choice = result["choices"][0]
  222. message = choice.get("message", {})
  223. content = message.get("content", "")
  224. if content:
  225. yield content.replace("\\(", "(").replace("\\)", ")")
  226. if "usage" in result:
  227. yield Usage(**result["usage"])
  228. finish_reason = choice.get("finish_reason")
  229. if finish_reason:
  230. yield FinishReason(finish_reason)