retry_provider.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from __future__ import annotations
  2. import asyncio
  3. import random
  4. from ..typing import CreateResult, Messages
  5. from ..base_provider import BaseRetryProvider
  6. from .. import debug
  7. from ..errors import RetryProviderError, RetryNoProviderError
  class RetryProvider(BaseRetryProvider):
      """Try each configured provider in turn until one of them succeeds.

      The methods below read ``self.providers`` and ``self.shuffle`` and write
      ``self.exceptions`` (provider name -> exception raised by that provider
      during the current call) and ``self.last_provider`` (the provider most
      recently attempted).
      """

      def create_completion(
          self,
          model: str,
          messages: Messages,
          stream: bool = False,
          **kwargs
      ) -> CreateResult:
          """Yield completion tokens from the first provider that produces output.

          Args:
              model: Model name forwarded to each provider.
              messages: Conversation messages forwarded to each provider.
              stream: When true, only providers whose ``supports_stream`` is
                  truthy are considered.
              **kwargs: Extra keyword arguments forwarded to each provider.

          Yields:
              Tokens produced by the provider's own ``create_completion``.

          Raises:
              Exception: Whatever a provider raised, re-raised unchanged, if it
                  failed after at least one token had been yielded.
              RetryProviderError: If at least one provider failed and none
                  produced output.
              RetryNoProviderError: If no provider produced output and no
                  exception was recorded.
          """
          # Always work on a private list: shuffling ``self.providers`` itself
          # would permanently reorder the shared instance state.
          if stream:
              providers = [provider for provider in self.providers if provider.supports_stream]
          else:
              providers = list(self.providers)
          if self.shuffle:
              random.shuffle(providers)

          self.exceptions = {}
          started: bool = False
          for provider in providers:
              self.last_provider = provider
              try:
                  if debug.logging:
                      print(f"Using {provider.__name__} provider")
                  for token in provider.create_completion(model, messages, stream, **kwargs):
                      yield token
                      started = True
                  if started:
                      return
              except Exception as e:
                  self.exceptions[provider.__name__] = e
                  if debug.logging:
                      print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
                  # Once output has been yielded, falling over to another
                  # provider would mix two responses, so propagate instead.
                  if started:
                      raise

          self.raise_exceptions()

      async def create_async(
          self,
          model: str,
          messages: Messages,
          **kwargs
      ) -> str:
          """Return the result of the first provider whose ``create_async`` succeeds.

          Each attempt is bounded by ``asyncio.wait_for`` using
          ``kwargs["timeout"]`` when present, otherwise 60. The ``timeout``
          entry, like every other keyword argument, is also forwarded to the
          provider.

          Args:
              model: Model name forwarded to each provider.
              messages: Conversation messages forwarded to each provider.
              **kwargs: Extra keyword arguments forwarded to each provider.

          Returns:
              The value returned by the first successful provider.

          Raises:
              RetryProviderError: If at least one provider failed and none
                  succeeded.
              RetryNoProviderError: If no provider succeeded and no exception
                  was recorded.
          """
          # Copy before shuffling: the loop below suspends at ``await``, and an
          # in-place shuffle of the shared list by another call could reorder
          # it while this call is still iterating over it.
          providers = list(self.providers)
          if self.shuffle:
              random.shuffle(providers)

          self.exceptions = {}
          for provider in providers:
              self.last_provider = provider
              try:
                  return await asyncio.wait_for(
                      provider.create_async(model, messages, **kwargs),
                      timeout=kwargs.get("timeout", 60)
                  )
              except Exception as e:
                  self.exceptions[provider.__name__] = e
                  if debug.logging:
                      print(f"{provider.__name__}: {e.__class__.__name__}: {e}")

          self.raise_exceptions()

      def raise_exceptions(self) -> None:
          """Raise an error describing why no provider delivered a result.

          This method never returns normally.

          Raises:
              RetryProviderError: If ``self.exceptions`` is non-empty; the
                  message lists one line per failed provider.
              RetryNoProviderError: If ``self.exceptions`` is empty.
          """
          if self.exceptions:
              raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
                  f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
              ]))

          raise RetryNoProviderError("No provider found")