retry_provider.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. from __future__ import annotations
  2. import asyncio
  3. import random
  4. from ..typing import Type, List, CreateResult, Messages, AsyncResult
  5. from .types import BaseProvider, BaseRetryProvider, ProviderType
  6. from .. import debug
  7. from ..errors import RetryProviderError, RetryNoProviderError
  8. DEFAULT_TIMEOUT = 60
  9. class IterListProvider(BaseRetryProvider):
  10. def __init__(
  11. self,
  12. providers: List[Type[BaseProvider]],
  13. shuffle: bool = True
  14. ) -> None:
  15. """
  16. Initialize the BaseRetryProvider.
  17. Args:
  18. providers (List[Type[BaseProvider]]): List of providers to use.
  19. shuffle (bool): Whether to shuffle the providers list.
  20. single_provider_retry (bool): Whether to retry a single provider if it fails.
  21. max_retries (int): Maximum number of retries for a single provider.
  22. """
  23. self.providers = providers
  24. self.shuffle = shuffle
  25. self.working = True
  26. self.last_provider: Type[BaseProvider] = None
  27. def create_completion(
  28. self,
  29. model: str,
  30. messages: Messages,
  31. stream: bool = False,
  32. ignore_stream: bool = False,
  33. ignored: list[str] = [],
  34. **kwargs,
  35. ) -> CreateResult:
  36. """
  37. Create a completion using available providers, with an option to stream the response.
  38. Args:
  39. model (str): The model to be used for completion.
  40. messages (Messages): The messages to be used for generating completion.
  41. stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
  42. Yields:
  43. CreateResult: Tokens or results from the completion.
  44. Raises:
  45. Exception: Any exception encountered during the completion process.
  46. """
  47. exceptions = {}
  48. started: bool = False
  49. for provider in self.get_providers(stream and not ignore_stream, ignored):
  50. self.last_provider = provider
  51. debug.log(f"Using {provider.__name__} provider")
  52. try:
  53. for chunk in provider.create_completion(model, messages, stream, **kwargs):
  54. if chunk:
  55. yield chunk
  56. started = True
  57. if started:
  58. return
  59. except Exception as e:
  60. exceptions[provider.__name__] = e
  61. debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  62. if started:
  63. raise e
  64. raise_exceptions(exceptions)
  65. async def create_async(
  66. self,
  67. model: str,
  68. messages: Messages,
  69. ignored: list[str] = [],
  70. **kwargs,
  71. ) -> str:
  72. """
  73. Asynchronously create a completion using available providers.
  74. Args:
  75. model (str): The model to be used for completion.
  76. messages (Messages): The messages to be used for generating completion.
  77. Returns:
  78. str: The result of the asynchronous completion.
  79. Raises:
  80. Exception: Any exception encountered during the asynchronous completion process.
  81. """
  82. exceptions = {}
  83. for provider in self.get_providers(False, ignored):
  84. self.last_provider = provider
  85. debug.log(f"Using {provider.__name__} provider")
  86. try:
  87. chunk = await asyncio.wait_for(
  88. provider.create_async(model, messages, **kwargs),
  89. timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
  90. )
  91. if chunk:
  92. return chunk
  93. except Exception as e:
  94. exceptions[provider.__name__] = e
  95. debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  96. raise_exceptions(exceptions)
  97. async def create_async_generator(
  98. self,
  99. model: str,
  100. messages: Messages,
  101. stream: bool = True,
  102. ignore_stream: bool = False,
  103. ignored: list[str] = [],
  104. **kwargs
  105. ) -> AsyncResult:
  106. exceptions = {}
  107. started: bool = False
  108. for provider in self.get_providers(stream and not ignore_stream, ignored):
  109. self.last_provider = provider
  110. debug.log(f"Using {provider.__name__} provider")
  111. try:
  112. if not stream:
  113. chunk = await asyncio.wait_for(
  114. provider.create_async(model, messages, **kwargs),
  115. timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
  116. )
  117. if chunk:
  118. yield chunk
  119. started = True
  120. elif hasattr(provider, "create_async_generator"):
  121. async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
  122. if chunk:
  123. yield chunk
  124. started = True
  125. else:
  126. for token in provider.create_completion(model, messages, stream, **kwargs):
  127. yield token
  128. started = True
  129. if started:
  130. return
  131. except Exception as e:
  132. exceptions[provider.__name__] = e
  133. debug.log(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  134. if started:
  135. raise e
  136. raise_exceptions(exceptions)
  137. def get_providers(self, stream: bool, ignored: list[str]) -> list[ProviderType]:
  138. providers = [p for p in self.providers if (p.supports_stream or not stream) and p.__name__ not in ignored]
  139. if self.shuffle:
  140. random.shuffle(providers)
  141. return providers
  142. class RetryProvider(IterListProvider):
  143. def __init__(
  144. self,
  145. providers: List[Type[BaseProvider]],
  146. shuffle: bool = True,
  147. single_provider_retry: bool = False,
  148. max_retries: int = 3,
  149. ) -> None:
  150. """
  151. Initialize the BaseRetryProvider.
  152. Args:
  153. providers (List[Type[BaseProvider]]): List of providers to use.
  154. shuffle (bool): Whether to shuffle the providers list.
  155. single_provider_retry (bool): Whether to retry a single provider if it fails.
  156. max_retries (int): Maximum number of retries for a single provider.
  157. """
  158. super().__init__(providers, shuffle)
  159. self.single_provider_retry = single_provider_retry
  160. self.max_retries = max_retries
  161. def create_completion(
  162. self,
  163. model: str,
  164. messages: Messages,
  165. stream: bool = False,
  166. **kwargs,
  167. ) -> CreateResult:
  168. """
  169. Create a completion using available providers, with an option to stream the response.
  170. Args:
  171. model (str): The model to be used for completion.
  172. messages (Messages): The messages to be used for generating completion.
  173. stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
  174. Yields:
  175. CreateResult: Tokens or results from the completion.
  176. Raises:
  177. Exception: Any exception encountered during the completion process.
  178. """
  179. if self.single_provider_retry:
  180. exceptions = {}
  181. started: bool = False
  182. provider = self.providers[0]
  183. self.last_provider = provider
  184. for attempt in range(self.max_retries):
  185. try:
  186. if debug.logging:
  187. print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
  188. for token in provider.create_completion(model, messages, stream, **kwargs):
  189. yield token
  190. started = True
  191. if started:
  192. return
  193. except Exception as e:
  194. exceptions[provider.__name__] = e
  195. if debug.logging:
  196. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  197. if started:
  198. raise e
  199. raise_exceptions(exceptions)
  200. else:
  201. yield from super().create_completion(model, messages, stream, **kwargs)
  202. async def create_async(
  203. self,
  204. model: str,
  205. messages: Messages,
  206. **kwargs,
  207. ) -> str:
  208. """
  209. Asynchronously create a completion using available providers.
  210. Args:
  211. model (str): The model to be used for completion.
  212. messages (Messages): The messages to be used for generating completion.
  213. Returns:
  214. str: The result of the asynchronous completion.
  215. Raises:
  216. Exception: Any exception encountered during the asynchronous completion process.
  217. """
  218. exceptions = {}
  219. if self.single_provider_retry:
  220. provider = self.providers[0]
  221. self.last_provider = provider
  222. for attempt in range(self.max_retries):
  223. try:
  224. if debug.logging:
  225. print(f"Using {provider.__name__} provider (attempt {attempt + 1})")
  226. return await asyncio.wait_for(
  227. provider.create_async(model, messages, **kwargs),
  228. timeout=kwargs.get("timeout", 60),
  229. )
  230. except Exception as e:
  231. exceptions[provider.__name__] = e
  232. if debug.logging:
  233. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  234. raise_exceptions(exceptions)
  235. else:
  236. return await super().create_async(model, messages, **kwargs)
  237. async def create_async_generator(
  238. self,
  239. model: str,
  240. messages: Messages,
  241. stream: bool = True,
  242. **kwargs
  243. ) -> AsyncResult:
  244. exceptions = {}
  245. started = False
  246. if self.single_provider_retry:
  247. provider = self.providers[0]
  248. self.last_provider = provider
  249. for attempt in range(self.max_retries):
  250. try:
  251. debug.log(f"Using {provider.__name__} provider (attempt {attempt + 1})")
  252. if not stream:
  253. chunk = await asyncio.wait_for(
  254. provider.create_async(model, messages, **kwargs),
  255. timeout=kwargs.get("timeout", DEFAULT_TIMEOUT),
  256. )
  257. if chunk:
  258. yield chunk
  259. started = True
  260. elif hasattr(provider, "create_async_generator"):
  261. async for chunk in provider.create_async_generator(model, messages, stream=stream, **kwargs):
  262. if chunk:
  263. yield chunk
  264. started = True
  265. else:
  266. for token in provider.create_completion(model, messages, stream, **kwargs):
  267. yield token
  268. started = True
  269. if started:
  270. return
  271. except Exception as e:
  272. exceptions[provider.__name__] = e
  273. if debug.logging:
  274. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  275. raise_exceptions(exceptions)
  276. else:
  277. async for chunk in super().create_async_generator(model, messages, stream, **kwargs):
  278. yield chunk
  279. def raise_exceptions(exceptions: dict) -> None:
  280. """
  281. Raise a combined exception if any occurred during retries.
  282. Raises:
  283. RetryProviderError: If any provider encountered an exception.
  284. RetryNoProviderError: If no provider is found.
  285. """
  286. if exceptions:
  287. raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
  288. f"{p}: {type(exception).__name__}: {exception}" for p, exception in exceptions.items()
  289. ]))
  290. raise RetryNoProviderError("No provider found")