# g4f client package (__init__.py)
from __future__ import annotations

import os
import time
import random
import string
import asyncio
import base64
import logging
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional, Any

from ..providers.base_provider import AsyncGeneratorProvider, get_running_loop
from ..image import ImageResponse, copy_images, images_dir
from ..typing import Messages, Image, ImageType
from ..providers.types import ProviderType
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData
from ..errors import NoImageResponseError, ModelNotFoundError
from ..providers.retry_provider import IterListProvider
from ..Provider.needs_auth.BingCreateImages import BingCreateImages
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, get_last_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_sync_iter, to_async_iterator
  24. ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
  25. AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
  26. try:
  27. anext # Python 3.8+
  28. except NameError:
  29. async def anext(aiter):
  30. try:
  31. return await aiter.__anext__()
  32. except StopAsyncIteration:
  33. raise StopIteration
  34. # Synchronous iter_response function
  35. def iter_response(
  36. response: Union[Iterator[Union[str, ResponseType]]],
  37. stream: bool,
  38. response_format: Optional[dict] = None,
  39. max_tokens: Optional[int] = None,
  40. stop: Optional[list[str]] = None
  41. ) -> ChatCompletionResponseType:
  42. content = ""
  43. finish_reason = None
  44. completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
  45. idx = 0
  46. if hasattr(response, '__aiter__'):
  47. # It's an async iterator, wrap it into a sync iterator
  48. response = to_sync_iter(response)
  49. for chunk in response:
  50. if isinstance(chunk, FinishReason):
  51. finish_reason = chunk.reason
  52. break
  53. elif isinstance(chunk, BaseConversation):
  54. yield chunk
  55. continue
  56. elif isinstance(chunk, SynthesizeData):
  57. continue
  58. chunk = str(chunk)
  59. content += chunk
  60. if max_tokens is not None and idx + 1 >= max_tokens:
  61. finish_reason = "length"
  62. first, content, chunk = find_stop(stop, content, chunk if stream else None)
  63. if first != -1:
  64. finish_reason = "stop"
  65. if stream:
  66. yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
  67. if finish_reason is not None:
  68. break
  69. idx += 1
  70. finish_reason = "stop" if finish_reason is None else finish_reason
  71. if stream:
  72. yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
  73. else:
  74. if response_format is not None and "type" in response_format:
  75. if response_format["type"] == "json_object":
  76. content = filter_json(content)
  77. yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
  78. # Synchronous iter_append_model_and_provider function
  79. def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType:
  80. last_provider = None
  81. for chunk in response:
  82. if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
  83. last_provider = get_last_provider(True) if last_provider is None else last_provider
  84. chunk.model = last_provider.get("model")
  85. chunk.provider = last_provider.get("name")
  86. yield chunk
  87. async def async_iter_response(
  88. response: AsyncIterator[Union[str, ResponseType]],
  89. stream: bool,
  90. response_format: Optional[dict] = None,
  91. max_tokens: Optional[int] = None,
  92. stop: Optional[list[str]] = None
  93. ) -> AsyncChatCompletionResponseType:
  94. content = ""
  95. finish_reason = None
  96. completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
  97. idx = 0
  98. try:
  99. async for chunk in response:
  100. if isinstance(chunk, FinishReason):
  101. finish_reason = chunk.reason
  102. break
  103. elif isinstance(chunk, BaseConversation):
  104. yield chunk
  105. continue
  106. elif isinstance(chunk, SynthesizeData):
  107. continue
  108. chunk = str(chunk)
  109. content += chunk
  110. idx += 1
  111. if max_tokens is not None and idx >= max_tokens:
  112. finish_reason = "length"
  113. first, content, chunk = find_stop(stop, content, chunk if stream else None)
  114. if first != -1:
  115. finish_reason = "stop"
  116. if stream:
  117. yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
  118. if finish_reason is not None:
  119. break
  120. finish_reason = "stop" if finish_reason is None else finish_reason
  121. if stream:
  122. yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
  123. else:
  124. if response_format is not None and "type" in response_format:
  125. if response_format["type"] == "json_object":
  126. content = filter_json(content)
  127. yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
  128. finally:
  129. if hasattr(response, 'aclose'):
  130. await safe_aclose(response)
  131. async def async_iter_append_model_and_provider(
  132. response: AsyncChatCompletionResponseType
  133. ) -> AsyncChatCompletionResponseType:
  134. last_provider = None
  135. try:
  136. async for chunk in response:
  137. if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
  138. last_provider = get_last_provider(True) if last_provider is None else last_provider
  139. chunk.model = last_provider.get("model")
  140. chunk.provider = last_provider.get("name")
  141. yield chunk
  142. finally:
  143. if hasattr(response, 'aclose'):
  144. await safe_aclose(response)
  145. class Client(BaseClient):
  146. def __init__(
  147. self,
  148. provider: Optional[ProviderType] = None,
  149. image_provider: Optional[ImageProvider] = None,
  150. **kwargs
  151. ) -> None:
  152. super().__init__(**kwargs)
  153. self.chat: Chat = Chat(self, provider)
  154. self.images: Images = Images(self, image_provider)
  155. class Completions:
  156. def __init__(self, client: Client, provider: Optional[ProviderType] = None):
  157. self.client: Client = client
  158. self.provider: ProviderType = provider
  159. def create(
  160. self,
  161. messages: Messages,
  162. model: str,
  163. provider: Optional[ProviderType] = None,
  164. stream: Optional[bool] = False,
  165. proxy: Optional[str] = None,
  166. response_format: Optional[dict] = None,
  167. max_tokens: Optional[int] = None,
  168. stop: Optional[Union[list[str], str]] = None,
  169. api_key: Optional[str] = None,
  170. ignored: Optional[list[str]] = None,
  171. ignore_working: Optional[bool] = False,
  172. ignore_stream: Optional[bool] = False,
  173. **kwargs
  174. ) -> IterResponse:
  175. model, provider = get_model_and_provider(
  176. model,
  177. self.provider if provider is None else provider,
  178. stream,
  179. ignored,
  180. ignore_working,
  181. ignore_stream,
  182. )
  183. stop = [stop] if isinstance(stop, str) else stop
  184. response = provider.create_completion(
  185. model,
  186. messages,
  187. stream=stream,
  188. **filter_none(
  189. proxy=self.client.proxy if proxy is None else proxy,
  190. max_tokens=max_tokens,
  191. stop=stop,
  192. api_key=self.client.api_key if api_key is None else api_key
  193. ),
  194. **kwargs
  195. )
  196. if asyncio.iscoroutinefunction(provider.create_completion):
  197. # Run the asynchronous function in an event loop
  198. response = asyncio.run(response)
  199. if stream and hasattr(response, '__aiter__'):
  200. # It's an async generator, wrap it into a sync iterator
  201. response = to_sync_iter(response)
  202. elif hasattr(response, '__aiter__'):
  203. # If response is an async generator, collect it into a list
  204. response = list(to_sync_iter(response))
  205. response = iter_response(response, stream, response_format, max_tokens, stop)
  206. response = iter_append_model_and_provider(response)
  207. if stream:
  208. return response
  209. else:
  210. return next(response)
  211. class Chat:
  212. completions: Completions
  213. def __init__(self, client: Client, provider: Optional[ProviderType] = None):
  214. self.completions = Completions(client, provider)
  215. class Images:
  216. def __init__(self, client: Client, provider: Optional[ProviderType] = None):
  217. self.client: Client = client
  218. self.provider: Optional[ProviderType] = provider
  219. self.models: ImageModels = ImageModels(client)
  220. def generate(
  221. self,
  222. prompt: str,
  223. model: str = None,
  224. provider: Optional[ProviderType] = None,
  225. response_format: str = "url",
  226. proxy: Optional[str] = None,
  227. **kwargs
  228. ) -> ImagesResponse:
  229. """
  230. Synchronous generate method that runs the async_generate method in an event loop.
  231. """
  232. return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))
  233. async def async_generate(
  234. self,
  235. prompt: str,
  236. model: Optional[str] = None,
  237. provider: Optional[ProviderType] = None,
  238. response_format: Optional[str] = "url",
  239. proxy: Optional[str] = None,
  240. **kwargs
  241. ) -> ImagesResponse:
  242. if provider is None:
  243. provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
  244. elif isinstance(provider, str):
  245. provider_handler = convert_to_provider(provider)
  246. else:
  247. provider_handler = provider
  248. if provider_handler is None:
  249. raise ModelNotFoundError(f"Unknown model: {model}")
  250. if isinstance(provider_handler, IterListProvider):
  251. if provider_handler.providers:
  252. provider_handler = provider_handler.providers[0]
  253. else:
  254. raise ModelNotFoundError(f"IterListProvider for model {model} has no providers")
  255. if proxy is None:
  256. proxy = self.client.proxy
  257. response = None
  258. if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
  259. messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
  260. async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs):
  261. if isinstance(item, ImageResponse):
  262. response = item
  263. break
  264. elif hasattr(provider_handler, 'create'):
  265. if asyncio.iscoroutinefunction(provider_handler.create):
  266. response = await provider_handler.create(prompt)
  267. else:
  268. response = provider_handler.create(prompt)
  269. if isinstance(response, str):
  270. response = ImageResponse([response], prompt)
  271. elif hasattr(provider_handler, "create_completion"):
  272. get_running_loop(check_nested=True)
  273. messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
  274. for item in provider_handler.create_completion(model, messages, prompt=prompt, **kwargs):
  275. if isinstance(item, ImageResponse):
  276. response = item
  277. break
  278. else:
  279. raise ValueError(f"Provider {provider} does not support image generation")
  280. if isinstance(response, ImageResponse):
  281. return await self._process_image_response(
  282. response,
  283. response_format,
  284. proxy,
  285. model,
  286. getattr(provider_handler, "__name__", None)
  287. )
  288. raise NoImageResponseError(f"Unexpected response type: {type(response)}")
  289. def create_variation(
  290. self,
  291. image: Union[str, bytes],
  292. model: str = None,
  293. provider: Optional[ProviderType] = None,
  294. response_format: str = "url",
  295. **kwargs
  296. ) -> ImagesResponse:
  297. return asyncio.run(self.async_create_variation(
  298. image, model, provider, response_format, **kwargs
  299. ))
  300. async def async_create_variation(
  301. self,
  302. image: ImageType,
  303. model: Optional[str] = None,
  304. provider: Optional[ProviderType] = None,
  305. response_format: str = "url",
  306. proxy: Optional[str] = None,
  307. **kwargs
  308. ) -> ImagesResponse:
  309. if provider is None:
  310. provider = self.models.get(model, provider or self.provider or BingCreateImages)
  311. if provider is None:
  312. raise ModelNotFoundError(f"Unknown model: {model}")
  313. if isinstance(provider, str):
  314. provider = convert_to_provider(provider)
  315. if proxy is None:
  316. proxy = self.client.proxy
  317. if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
  318. messages = [{"role": "user", "content": "create a variation of this image"}]
  319. generator = None
  320. try:
  321. generator = provider.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs)
  322. async for chunk in generator:
  323. if isinstance(chunk, ImageResponse):
  324. response = chunk
  325. break
  326. finally:
  327. if generator and hasattr(generator, 'aclose'):
  328. await safe_aclose(generator)
  329. elif hasattr(provider, 'create_variation'):
  330. if asyncio.iscoroutinefunction(provider.create_variation):
  331. response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
  332. else:
  333. response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
  334. else:
  335. raise NoImageResponseError(f"Provider {provider} does not support image variation")
  336. if isinstance(response, str):
  337. response = ImageResponse([response])
  338. if isinstance(response, ImageResponse):
  339. return self._process_image_response(response, response_format, proxy, model, getattr(provider, "__name__", None))
  340. raise NoImageResponseError(f"Unexpected response type: {type(response)}")
  341. async def _process_image_response(
  342. self,
  343. response: ImageResponse,
  344. response_format: str,
  345. proxy: str = None,
  346. model: Optional[str] = None,
  347. provider: Optional[str] = None
  348. ) -> list[Image]:
  349. if response_format in ("url", "b64_json"):
  350. images = await copy_images(response.get_list(), response.options.get("cookies"), proxy)
  351. async def process_image_item(image_file: str) -> Image:
  352. if response_format == "b64_json":
  353. with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file:
  354. image_data = base64.b64encode(file.read()).decode()
  355. return Image(url=image_file, b64_json=image_data, revised_prompt=response.alt)
  356. return Image(url=image_file, revised_prompt=response.alt)
  357. images = await asyncio.gather(*[process_image_item(image) for image in images])
  358. else:
  359. images = [Image(url=image, revised_prompt=response.alt) for image in response.get_list()]
  360. last_provider = get_last_provider(True)
  361. return ImagesResponse(
  362. images,
  363. model=last_provider.get("model") if model is None else model,
  364. provider=last_provider.get("name") if provider is None else provider
  365. )
  366. class AsyncClient(BaseClient):
  367. def __init__(
  368. self,
  369. provider: Optional[ProviderType] = None,
  370. image_provider: Optional[ImageProvider] = None,
  371. **kwargs
  372. ) -> None:
  373. super().__init__(**kwargs)
  374. self.chat: AsyncChat = AsyncChat(self, provider)
  375. self.images: AsyncImages = AsyncImages(self, image_provider)
  376. class AsyncChat:
  377. completions: AsyncCompletions
  378. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  379. self.completions = AsyncCompletions(client, provider)
  380. class AsyncCompletions:
  381. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  382. self.client: AsyncClient = client
  383. self.provider: ProviderType = provider
  384. def create(
  385. self,
  386. messages: Messages,
  387. model: str,
  388. provider: Optional[ProviderType] = None,
  389. stream: Optional[bool] = False,
  390. proxy: Optional[str] = None,
  391. response_format: Optional[dict] = None,
  392. max_tokens: Optional[int] = None,
  393. stop: Optional[Union[list[str], str]] = None,
  394. api_key: Optional[str] = None,
  395. ignored: Optional[list[str]] = None,
  396. ignore_working: Optional[bool] = False,
  397. ignore_stream: Optional[bool] = False,
  398. **kwargs
  399. ) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]:
  400. model, provider = get_model_and_provider(
  401. model,
  402. self.provider if provider is None else provider,
  403. stream,
  404. ignored,
  405. ignore_working,
  406. ignore_stream,
  407. )
  408. stop = [stop] if isinstance(stop, str) else stop
  409. response = provider.create_completion(
  410. model,
  411. messages,
  412. stream=stream,
  413. **filter_none(
  414. proxy=self.client.proxy if proxy is None else proxy,
  415. max_tokens=max_tokens,
  416. stop=stop,
  417. api_key=self.client.api_key if api_key is None else api_key
  418. ),
  419. **kwargs
  420. )
  421. if not isinstance(response, AsyncIterator):
  422. response = to_async_iterator(response)
  423. response = async_iter_response(response, stream, response_format, max_tokens, stop)
  424. response = async_iter_append_model_and_provider(response)
  425. return response if stream else anext(response)
  426. class AsyncImages(Images):
  427. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  428. self.client: AsyncClient = client
  429. self.provider: Optional[ProviderType] = provider
  430. self.models: ImageModels = ImageModels(client)
  431. async def generate(
  432. self,
  433. prompt: str,
  434. model: Optional[str] = None,
  435. provider: Optional[ProviderType] = None,
  436. response_format: str = "url",
  437. **kwargs
  438. ) -> ImagesResponse:
  439. return await self.async_generate(prompt, model, provider, response_format, **kwargs)
  440. async def create_variation(
  441. self,
  442. image: ImageType,
  443. model: str = None,
  444. provider: ProviderType = None,
  445. response_format: str = "url",
  446. **kwargs
  447. ) -> ImagesResponse:
  448. return await self.async_create_variation(
  449. image, model, provider, response_format, **kwargs
  450. )