retry_provider.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. from __future__ import annotations
  2. import asyncio
  3. import random
  4. from typing import List, Type, Dict
  5. from ..typing import CreateResult, Messages
  6. from .base_provider import BaseProvider, AsyncProvider
  7. from .. import debug
  8. from ..errors import RetryProviderError, RetryNoProviderError
  9. class RetryProvider(AsyncProvider):
  10. __name__: str = "RetryProvider"
  11. supports_stream: bool = True
  12. def __init__(
  13. self,
  14. providers: List[Type[BaseProvider]],
  15. shuffle: bool = True
  16. ) -> None:
  17. self.providers: List[Type[BaseProvider]] = providers
  18. self.shuffle: bool = shuffle
  19. self.working = True
  20. def create_completion(
  21. self,
  22. model: str,
  23. messages: Messages,
  24. stream: bool = False,
  25. **kwargs
  26. ) -> CreateResult:
  27. if stream:
  28. providers = [provider for provider in self.providers if provider.supports_stream]
  29. else:
  30. providers = self.providers
  31. if self.shuffle:
  32. random.shuffle(providers)
  33. self.exceptions: Dict[str, Exception] = {}
  34. started: bool = False
  35. for provider in providers:
  36. try:
  37. if debug.logging:
  38. print(f"Using {provider.__name__} provider")
  39. for token in provider.create_completion(model, messages, stream, **kwargs):
  40. yield token
  41. started = True
  42. if started:
  43. return
  44. except Exception as e:
  45. self.exceptions[provider.__name__] = e
  46. if debug.logging:
  47. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  48. if started:
  49. raise e
  50. self.raise_exceptions()
  51. async def create_async(
  52. self,
  53. model: str,
  54. messages: Messages,
  55. **kwargs
  56. ) -> str:
  57. providers = self.providers
  58. if self.shuffle:
  59. random.shuffle(providers)
  60. self.exceptions: Dict[str, Exception] = {}
  61. for provider in providers:
  62. try:
  63. return await asyncio.wait_for(
  64. provider.create_async(model, messages, **kwargs),
  65. timeout=kwargs.get("timeout", 60)
  66. )
  67. except Exception as e:
  68. self.exceptions[provider.__name__] = e
  69. if debug.logging:
  70. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  71. self.raise_exceptions()
  72. def raise_exceptions(self) -> None:
  73. if self.exceptions:
  74. raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
  75. f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
  76. ]))
  77. raise RetryNoProviderError("No provider found")