# __init__.py — g4f client module
  1. from __future__ import annotations
  2. import os
  3. import time
  4. import random
  5. import string
  6. import asyncio
  7. import aiohttp
  8. import base64
  9. from typing import Union, AsyncIterator, Iterator, Awaitable, Optional
  10. from ..image.copy_images import copy_images
  11. from ..typing import Messages, ImageType
  12. from ..providers.types import ProviderType, BaseRetryProvider
  13. from ..providers.response import ResponseType, ImageResponse, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
  14. from ..errors import NoImageResponseError
  15. from ..providers.retry_provider import IterListProvider
  16. from ..providers.asyncio import to_sync_generator
  17. from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
  18. from ..tools.run_tools import async_iter_run_tools, iter_run_tools
  19. from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
  20. from .image_models import ImageModels
  21. from .types import IterResponse, ImageProvider, Client as BaseClient
  22. from .service import get_model_and_provider, convert_to_provider
  23. from .helper import find_stop, filter_json, filter_none, safe_aclose
  24. from .. import debug
# Public aliases for the streams produced by this module: a (sync or async)
# iterator of completion objects, chunks, or conversation-state markers.
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
  27. try:
  28. anext # Python 3.8+
  29. except NameError:
  30. async def anext(aiter):
  31. try:
  32. return await aiter.__anext__()
  33. except StopAsyncIteration:
  34. raise StopIteration
# Synchronous iter_response function
def iter_response(
    response: Union[Iterator[Union[str, ResponseType]]],
    stream: bool,
    response_format: Optional[dict] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[list[str]] = None
) -> ChatCompletionResponseType:
    """Fold raw provider chunks into ChatCompletion / ChatCompletionChunk objects.

    When *stream* is true, yields one ChatCompletionChunk per text chunk plus a
    final chunk carrying the finish reason and usage; otherwise accumulates all
    text and yields a single ChatCompletion at the end.
    """
    content = ""
    finish_reason = None
    tool_calls = None
    usage = None
    # Random 28-char id shared by every chunk of this completion.
    completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
    idx = 0
    if hasattr(response, '__aiter__'):
        # Provider handed back an async iterator; bridge it to sync.
        response = to_sync_generator(response)
    for chunk in response:
        # Control objects are consumed here and not emitted as text.
        if isinstance(chunk, FinishReason):
            finish_reason = chunk.reason
            break
        elif isinstance(chunk, ToolCalls):
            tool_calls = chunk.get_list()
            continue
        elif isinstance(chunk, Usage):
            usage = chunk
            continue
        elif isinstance(chunk, BaseConversation):
            # Conversation state is passed straight through to the caller.
            yield chunk
            continue
        elif isinstance(chunk, SynthesizeData) or not chunk:
            continue
        elif isinstance(chunk, Exception):
            # Provider-reported errors are silently skipped here.
            continue
        if isinstance(chunk, list):
            chunk = "".join(map(str, chunk))
        else:
            # Defensive string coercion: __str__ of a misbehaving chunk type may
            # not return a str, so fall back to joining or repr().
            temp = chunk.__str__()
            if not isinstance(temp, str):
                if isinstance(temp, list):
                    temp = "".join(map(str, temp))
                else:
                    temp = repr(chunk)
            chunk = temp
        content += chunk
        # idx is incremented at the loop's end, so idx + 1 is the count of text
        # chunks seen so far (chunk-count based limit, not token based).
        if max_tokens is not None and idx + 1 >= max_tokens:
            finish_reason = "length"
        # find_stop may truncate both accumulated content and current chunk.
        first, content, chunk = find_stop(stop, content, chunk if stream else None)
        if first != -1:
            finish_reason = "stop"
        if stream:
            yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))
        if finish_reason is not None:
            break
        idx += 1
    if usage is None:
        # Provider sent no Usage object; approximate with the chunk count.
        usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
    finish_reason = "stop" if finish_reason is None else finish_reason
    if stream:
        # Terminal chunk: no content, carries finish_reason and usage.
        yield ChatCompletionChunk.model_construct(
            None, finish_reason, completion_id, int(time.time()),
            usage=usage.get_dict()
        )
    else:
        if response_format is not None and "type" in response_format:
            if response_format["type"] == "json_object":
                # Strip any non-JSON wrapper text from the accumulated content.
                content = filter_json(content)
        yield ChatCompletion.model_construct(
            content, finish_reason, completion_id, int(time.time()),
            usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
        )
  105. # Synchronous iter_append_model_and_provider function
  106. def iter_append_model_and_provider(response: ChatCompletionResponseType, last_model: str, last_provider: ProviderType) -> ChatCompletionResponseType:
  107. if isinstance(last_provider, BaseRetryProvider):
  108. last_provider = last_provider.last_provider
  109. for chunk in response:
  110. if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
  111. if last_provider is not None:
  112. chunk.model = getattr(last_provider, "last_model", last_model)
  113. chunk.provider = last_provider.__name__
  114. yield chunk
async def async_iter_response(
    response: AsyncIterator[Union[str, ResponseType]],
    stream: bool,
    response_format: Optional[dict] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[list[str]] = None
) -> AsyncChatCompletionResponseType:
    """Async counterpart of iter_response: fold raw provider chunks into
    ChatCompletion / ChatCompletionChunk objects.

    The source iterator is always closed via safe_aclose, even if the
    consumer abandons this generator early.
    """
    content = ""
    finish_reason = None
    # Random 28-char id shared by every chunk of this completion.
    completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
    idx = 0
    tool_calls = None
    usage = None
    try:
        async for chunk in response:
            # Control objects are consumed here and not emitted as text.
            if isinstance(chunk, FinishReason):
                finish_reason = chunk.reason
                break
            elif isinstance(chunk, BaseConversation):
                # Conversation state is passed straight through to the caller.
                yield chunk
                continue
            elif isinstance(chunk, ToolCalls):
                tool_calls = chunk.get_list()
                continue
            elif isinstance(chunk, Usage):
                usage = chunk
                continue
            elif isinstance(chunk, SynthesizeData) or not chunk:
                continue
            elif isinstance(chunk, Exception):
                # Provider-reported errors are silently skipped here.
                continue
            chunk = str(chunk)
            content += chunk
            idx += 1
            # Chunk-count based limit (not token based).
            if max_tokens is not None and idx >= max_tokens:
                finish_reason = "length"
            # find_stop may truncate both accumulated content and current chunk.
            first, content, chunk = find_stop(stop, content, chunk if stream else None)
            if first != -1:
                finish_reason = "stop"
            if stream:
                yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))
            if finish_reason is not None:
                break
        finish_reason = "stop" if finish_reason is None else finish_reason
        if usage is None:
            # Provider sent no Usage object; approximate with the chunk count.
            usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
        if stream:
            # Terminal chunk: no content, carries finish_reason and usage.
            yield ChatCompletionChunk.model_construct(
                None, finish_reason, completion_id, int(time.time()),
                usage=usage.get_dict()
            )
        else:
            if response_format is not None and "type" in response_format:
                if response_format["type"] == "json_object":
                    # Strip any non-JSON wrapper text from the accumulated content.
                    content = filter_json(content)
            yield ChatCompletion.model_construct(
                content, finish_reason, completion_id, int(time.time()),
                usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
            )
    finally:
        await safe_aclose(response)
  176. async def async_iter_append_model_and_provider(
  177. response: AsyncChatCompletionResponseType,
  178. last_model: str,
  179. last_provider: ProviderType
  180. ) -> AsyncChatCompletionResponseType:
  181. last_provider = None
  182. try:
  183. if isinstance(last_provider, BaseRetryProvider):
  184. if last_provider is not None:
  185. last_provider = last_provider.last_provider
  186. async for chunk in response:
  187. if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
  188. if last_provider is not None:
  189. chunk.model = getattr(last_provider, "last_model", last_model)
  190. chunk.provider = last_provider.__name__
  191. yield chunk
  192. finally:
  193. await safe_aclose(response)
  194. class Client(BaseClient):
  195. def __init__(
  196. self,
  197. provider: Optional[ProviderType] = None,
  198. image_provider: Optional[ImageProvider] = None,
  199. **kwargs
  200. ) -> None:
  201. super().__init__(**kwargs)
  202. self.chat: Chat = Chat(self, provider)
  203. self.images: Images = Images(self, image_provider)
  204. class Completions:
  205. def __init__(self, client: Client, provider: Optional[ProviderType] = None):
  206. self.client: Client = client
  207. self.provider: ProviderType = provider
  208. def create(
  209. self,
  210. messages: Messages,
  211. model: str,
  212. provider: Optional[ProviderType] = None,
  213. stream: Optional[bool] = False,
  214. proxy: Optional[str] = None,
  215. image: Optional[ImageType] = None,
  216. image_name: Optional[str] = None,
  217. response_format: Optional[dict] = None,
  218. max_tokens: Optional[int] = None,
  219. stop: Optional[Union[list[str], str]] = None,
  220. api_key: Optional[str] = None,
  221. ignore_working: Optional[bool] = False,
  222. ignore_stream: Optional[bool] = False,
  223. **kwargs
  224. ) -> ChatCompletion:
  225. if image is not None:
  226. kwargs["images"] = [(image, image_name)]
  227. model, provider = get_model_and_provider(
  228. model,
  229. self.provider if provider is None else provider,
  230. stream,
  231. ignore_working,
  232. ignore_stream,
  233. has_images="images" in kwargs
  234. )
  235. stop = [stop] if isinstance(stop, str) else stop
  236. if ignore_stream:
  237. kwargs["ignore_stream"] = True
  238. response = iter_run_tools(
  239. provider.get_create_function(),
  240. model,
  241. messages,
  242. stream=stream,
  243. **filter_none(
  244. proxy=self.client.proxy if proxy is None else proxy,
  245. max_tokens=max_tokens,
  246. stop=stop,
  247. api_key=self.client.api_key if api_key is None else api_key
  248. ),
  249. **kwargs
  250. )
  251. response = iter_response(response, stream, response_format, max_tokens, stop)
  252. response = iter_append_model_and_provider(response, model, provider)
  253. if stream:
  254. return response
  255. else:
  256. return next(response)
  257. def stream(
  258. self,
  259. messages: Messages,
  260. model: str,
  261. **kwargs
  262. ) -> IterResponse:
  263. return self.create(messages, model, stream=True, **kwargs)
  264. class Chat:
  265. completions: Completions
  266. def __init__(self, client: Client, provider: Optional[ProviderType] = None):
  267. self.completions = Completions(client, provider)
class Images:
    """Synchronous image API (``client.images``): text-to-image generation and
    image variations. Sync entry points delegate to async implementations via
    ``asyncio.run``."""

    def __init__(self, client: Client, provider: Optional[ProviderType] = None):
        self.client: Client = client
        # Default image provider; may be overridden per call.
        self.provider: Optional[ProviderType] = provider
        self.models: ImageModels = ImageModels(client)

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        """
        Synchronous generate method that runs the async_generate method in an event loop.
        """
        return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))

    async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider:
        """Resolve which provider to use.

        Precedence: explicit *provider* argument > instance default >
        per-model registry lookup > *default*.
        """
        if provider is None:
            provider_handler = self.provider
            if provider_handler is None:
                provider_handler = self.models.get(model, default)
        elif isinstance(provider, str):
            # Provider given by name; resolve it to a provider class.
            provider_handler = convert_to_provider(provider)
        else:
            provider_handler = provider
        if provider_handler is None:
            return default
        return provider_handler

    async def async_generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        """Generate images for *prompt*.

        Raises:
            NoImageResponseError: if no provider produced an ImageResponse.
            Exception: the last provider error, if every provider failed.
        """
        provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
        provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
        if proxy is None:
            proxy = self.client.proxy
        error = None
        response = None
        if isinstance(provider_handler, IterListProvider):
            # Try each wrapped provider until one yields a response; remember
            # the last error so it can be re-raised if all of them fail.
            for provider in provider_handler.providers:
                try:
                    response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
                    if response is not None:
                        provider_name = provider.__name__
                        break
                except Exception as e:
                    error = e
                    debug.log(f"Image provider {provider.__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
        if isinstance(response, ImageResponse):
            return await self._process_image_response(
                response,
                model,
                provider_name,
                response_format,
                proxy
            )
        if response is None:
            if error is not None:
                raise error
            raise NoImageResponseError(f"No image response from {provider_name}")
        raise NoImageResponseError(f"Unexpected response type: {type(response)}")

    async def _generate_image_response(
        self,
        provider_handler,
        provider_name,
        model: str,
        prompt: str,
        prompt_prefix: str = "Generate a image: ",
        **kwargs
    ) -> ImageResponse:
        """Drive a single provider and return the first ImageResponse it
        yields, or None if the stream ends without one."""
        messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
        response = None
        if hasattr(provider_handler, "create_async_generator"):
            async for item in provider_handler.create_async_generator(
                model,
                messages,
                stream=True,
                prompt=prompt,
                **kwargs
            ):
                if isinstance(item, ImageResponse):
                    response = item
                    break
        elif hasattr(provider_handler, "create_completion"):
            # Fallback path for providers that only expose a sync API.
            for item in provider_handler.create_completion(
                model,
                messages,
                True,
                prompt=prompt,
                **kwargs
            ):
                if isinstance(item, ImageResponse):
                    response = item
                    break
        else:
            raise ValueError(f"Provider {provider_name} does not support image generation")
        return response

    def create_variation(
        self,
        image: ImageType,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        """Synchronous wrapper around async_create_variation."""
        return asyncio.run(self.async_create_variation(
            image, model, provider, response_format, **kwargs
        ))

    async def async_create_variation(
        self,
        image: ImageType,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        """Create a variation of *image* (default provider: OpenaiAccount).

        Raises:
            NoImageResponseError: if no provider produced an ImageResponse.
            Exception: the last provider error, if every provider failed.
        """
        provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount)
        provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
        if proxy is None:
            proxy = self.client.proxy
        # The variation is requested through a fixed chat prompt plus the image.
        prompt = "create a variation of this image"
        if image is not None:
            kwargs["images"] = [(image, None)]
        error = None
        response = None
        if isinstance(provider_handler, IterListProvider):
            # Same fallback loop as async_generate.
            for provider in provider_handler.providers:
                try:
                    response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
                    if response is not None:
                        provider_name = provider.__name__
                        break
                except Exception as e:
                    error = e
                    debug.log(f"Image provider {provider.__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
        if isinstance(response, ImageResponse):
            return await self._process_image_response(response, model, provider_name, response_format, proxy)
        if response is None:
            if error is not None:
                raise error
            raise NoImageResponseError(f"No image response from {provider_name}")
        raise NoImageResponseError(f"Unexpected response type: {type(response)}")

    async def _process_image_response(
        self,
        response: ImageResponse,
        model: str,
        provider: str,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None
    ) -> ImagesResponse:
        """Convert an ImageResponse into an OpenAI-style ImagesResponse,
        honoring *response_format* ("url", "b64_json", or None = save locally)."""
        if response_format == "url":
            # Return original URLs without saving locally
            images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in response.get_list()]
        elif response_format == "b64_json":
            # Convert URLs directly to base64 without saving
            async def get_b64_from_url(url: str) -> Image:
                async with aiohttp.ClientSession(cookies=response.get("cookies")) as session:
                    async with session.get(url, proxy=proxy) as resp:
                        if resp.status == 200:
                            image_data = await resp.read()
                            b64_data = base64.b64encode(image_data).decode()
                            return Image.model_construct(b64_json=b64_data, revised_prompt=response.alt)
                        # NOTE(review): non-200 responses fall through and return
                        # None implicitly — confirm callers tolerate None entries.
            images = await asyncio.gather(*[get_b64_from_url(image) for image in response.get_list()])
        else:
            # Save locally for None (default) case
            images = await copy_images(response.get_list(), response.get("cookies"), proxy)
            images = [Image.model_construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
        return ImagesResponse.model_construct(
            created=int(time.time()),
            data=images,
            model=model,
            provider=provider
        )
  453. class AsyncClient(BaseClient):
  454. def __init__(
  455. self,
  456. provider: Optional[ProviderType] = None,
  457. image_provider: Optional[ImageProvider] = None,
  458. **kwargs
  459. ) -> None:
  460. super().__init__(**kwargs)
  461. self.chat: AsyncChat = AsyncChat(self, provider)
  462. self.images: AsyncImages = AsyncImages(self, image_provider)
  463. class AsyncChat:
  464. completions: AsyncCompletions
  465. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  466. self.completions = AsyncCompletions(client, provider)
  467. class AsyncCompletions:
  468. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  469. self.client: AsyncClient = client
  470. self.provider: ProviderType = provider
  471. def create(
  472. self,
  473. messages: Messages,
  474. model: str,
  475. provider: Optional[ProviderType] = None,
  476. stream: Optional[bool] = False,
  477. proxy: Optional[str] = None,
  478. image: Optional[ImageType] = None,
  479. image_name: Optional[str] = None,
  480. response_format: Optional[dict] = None,
  481. max_tokens: Optional[int] = None,
  482. stop: Optional[Union[list[str], str]] = None,
  483. api_key: Optional[str] = None,
  484. ignore_working: Optional[bool] = False,
  485. ignore_stream: Optional[bool] = False,
  486. **kwargs
  487. ) -> Awaitable[ChatCompletion]:
  488. if image is not None:
  489. kwargs["images"] = [(image, image_name)]
  490. model, provider = get_model_and_provider(
  491. model,
  492. self.provider if provider is None else provider,
  493. stream,
  494. ignore_working,
  495. ignore_stream,
  496. has_images="images" in kwargs,
  497. )
  498. stop = [stop] if isinstance(stop, str) else stop
  499. if ignore_stream:
  500. kwargs["ignore_stream"] = True
  501. response = async_iter_run_tools(
  502. provider,
  503. model,
  504. messages,
  505. stream=stream,
  506. **filter_none(
  507. proxy=self.client.proxy if proxy is None else proxy,
  508. max_tokens=max_tokens,
  509. stop=stop,
  510. api_key=self.client.api_key if api_key is None else api_key
  511. ),
  512. **kwargs
  513. )
  514. response = async_iter_response(response, stream, response_format, max_tokens, stop)
  515. response = async_iter_append_model_and_provider(response, model, provider)
  516. if stream:
  517. return response
  518. else:
  519. return anext(response)
  520. def stream(
  521. self,
  522. messages: Messages,
  523. model: str,
  524. **kwargs
  525. ) -> AsyncIterator[ChatCompletionChunk, BaseConversation]:
  526. return self.create(messages, model, stream=True, **kwargs)
  527. class AsyncImages(Images):
  528. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  529. self.client: AsyncClient = client
  530. self.provider: Optional[ProviderType] = provider
  531. self.models: ImageModels = ImageModels(client)
  532. async def generate(
  533. self,
  534. prompt: str,
  535. model: Optional[str] = None,
  536. provider: Optional[ProviderType] = None,
  537. response_format: Optional[str] = None,
  538. **kwargs
  539. ) -> ImagesResponse:
  540. return await self.async_generate(prompt, model, provider, response_format, **kwargs)
  541. async def create_variation(
  542. self,
  543. image: ImageType,
  544. model: str = None,
  545. provider: ProviderType = None,
  546. response_format: Optional[str] = None,
  547. **kwargs
  548. ) -> ImagesResponse:
  549. return await self.async_create_variation(
  550. image, model, provider, response_format, **kwargs
  551. )