retry_provider.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. from __future__ import annotations
  2. import asyncio
  3. import random
  4. from ..typing import Type, List, CreateResult, Messages, Iterator
  5. from .types import BaseProvider, BaseRetryProvider
  6. from .. import debug
  7. from ..errors import RetryProviderError, RetryNoProviderError
  8. class RetryProvider(BaseRetryProvider):
  9. def __init__(
  10. self,
  11. providers: List[Type[BaseProvider]],
  12. shuffle: bool = True
  13. ) -> None:
  14. """
  15. Initialize the BaseRetryProvider.
  16. Args:
  17. providers (List[Type[BaseProvider]]): List of providers to use.
  18. shuffle (bool): Whether to shuffle the providers list.
  19. """
  20. self.providers = providers
  21. self.shuffle = shuffle
  22. self.working = True
  23. self.last_provider: Type[BaseProvider] = None
  24. """
  25. A provider class to handle retries for creating completions with different providers.
  26. Attributes:
  27. providers (list): A list of provider instances.
  28. shuffle (bool): A flag indicating whether to shuffle providers before use.
  29. last_provider (BaseProvider): The last provider that was used.
  30. """
  31. def create_completion(
  32. self,
  33. model: str,
  34. messages: Messages,
  35. stream: bool = False,
  36. **kwargs
  37. ) -> CreateResult:
  38. """
  39. Create a completion using available providers, with an option to stream the response.
  40. Args:
  41. model (str): The model to be used for completion.
  42. messages (Messages): The messages to be used for generating completion.
  43. stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
  44. Yields:
  45. CreateResult: Tokens or results from the completion.
  46. Raises:
  47. Exception: Any exception encountered during the completion process.
  48. """
  49. providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
  50. if self.shuffle:
  51. random.shuffle(providers)
  52. exceptions = {}
  53. started: bool = False
  54. for provider in providers:
  55. self.last_provider = provider
  56. try:
  57. if debug.logging:
  58. print(f"Using {provider.__name__} provider")
  59. for token in provider.create_completion(model, messages, stream, **kwargs):
  60. yield token
  61. started = True
  62. if started:
  63. return
  64. except Exception as e:
  65. exceptions[provider.__name__] = e
  66. if debug.logging:
  67. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  68. if started:
  69. raise e
  70. raise_exceptions(exceptions)
  71. async def create_async(
  72. self,
  73. model: str,
  74. messages: Messages,
  75. **kwargs
  76. ) -> str:
  77. """
  78. Asynchronously create a completion using available providers.
  79. Args:
  80. model (str): The model to be used for completion.
  81. messages (Messages): The messages to be used for generating completion.
  82. Returns:
  83. str: The result of the asynchronous completion.
  84. Raises:
  85. Exception: Any exception encountered during the asynchronous completion process.
  86. """
  87. providers = self.providers
  88. if self.shuffle:
  89. random.shuffle(providers)
  90. exceptions = {}
  91. for provider in providers:
  92. self.last_provider = provider
  93. try:
  94. return await asyncio.wait_for(
  95. provider.create_async(model, messages, **kwargs),
  96. timeout=kwargs.get("timeout", 60)
  97. )
  98. except Exception as e:
  99. exceptions[provider.__name__] = e
  100. if debug.logging:
  101. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  102. raise_exceptions(exceptions)
  103. class IterProvider(BaseRetryProvider):
  104. __name__ = "IterProvider"
  105. def __init__(
  106. self,
  107. providers: List[BaseProvider],
  108. ) -> None:
  109. providers.reverse()
  110. self.providers: List[BaseProvider] = providers
  111. self.working: bool = True
  112. self.last_provider: BaseProvider = None
  113. def create_completion(
  114. self,
  115. model: str,
  116. messages: Messages,
  117. stream: bool = False,
  118. **kwargs
  119. ) -> CreateResult:
  120. exceptions: dict = {}
  121. started: bool = False
  122. for provider in self.iter_providers():
  123. if stream and not provider.supports_stream:
  124. continue
  125. try:
  126. for token in provider.create_completion(model, messages, stream, **kwargs):
  127. yield token
  128. started = True
  129. if started:
  130. return
  131. except Exception as e:
  132. exceptions[provider.__name__] = e
  133. if debug.logging:
  134. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  135. if started:
  136. raise e
  137. raise_exceptions(exceptions)
  138. async def create_async(
  139. self,
  140. model: str,
  141. messages: Messages,
  142. **kwargs
  143. ) -> str:
  144. exceptions: dict = {}
  145. for provider in self.iter_providers():
  146. try:
  147. return await asyncio.wait_for(
  148. provider.create_async(model, messages, **kwargs),
  149. timeout=kwargs.get("timeout", 60)
  150. )
  151. except Exception as e:
  152. exceptions[provider.__name__] = e
  153. if debug.logging:
  154. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  155. raise_exceptions(exceptions)
  156. def iter_providers(self) -> Iterator[BaseProvider]:
  157. used_provider = []
  158. try:
  159. while self.providers:
  160. provider = self.providers.pop()
  161. used_provider.append(provider)
  162. self.last_provider = provider
  163. if debug.logging:
  164. print(f"Using {provider.__name__} provider")
  165. yield provider
  166. finally:
  167. used_provider.reverse()
  168. self.providers = [*used_provider, *self.providers]
  169. def raise_exceptions(exceptions: dict) -> None:
  170. """
  171. Raise a combined exception if any occurred during retries.
  172. Raises:
  173. RetryProviderError: If any provider encountered an exception.
  174. RetryNoProviderError: If no provider is found.
  175. """
  176. if exceptions:
  177. raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
  178. f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in exceptions.items()
  179. ]))
  180. raise RetryNoProviderError("No provider found")