# base_provider.py (1.4 KB)

from abc import ABC, abstractmethod

from .typing import Messages, CreateResult, Union
  3. class BaseProvider(ABC):
  4. url: str
  5. working: bool = False
  6. needs_auth: bool = False
  7. supports_stream: bool = False
  8. supports_gpt_35_turbo: bool = False
  9. supports_gpt_4: bool = False
  10. supports_message_history: bool = False
  11. params: str
  12. @classmethod
  13. @abstractmethod
  14. def create_completion(
  15. cls,
  16. model: str,
  17. messages: Messages,
  18. stream: bool,
  19. **kwargs
  20. ) -> CreateResult:
  21. raise NotImplementedError()
  22. @classmethod
  23. @abstractmethod
  24. async def create_async(
  25. cls,
  26. model: str,
  27. messages: Messages,
  28. **kwargs
  29. ) -> str:
  30. raise NotImplementedError()
  31. @classmethod
  32. def get_dict(cls):
  33. return {'name': cls.__name__, 'url': cls.url}
  34. class BaseRetryProvider(BaseProvider):
  35. __name__: str = "RetryProvider"
  36. supports_stream: bool = True
  37. def __init__(
  38. self,
  39. providers: list[type[BaseProvider]],
  40. shuffle: bool = True
  41. ) -> None:
  42. self.providers: list[type[BaseProvider]] = providers
  43. self.shuffle: bool = shuffle
  44. self.working: bool = True
  45. self.exceptions: dict[str, Exception] = {}
  46. self.last_provider: type[BaseProvider] = None
  47. ProviderType = Union[type[BaseProvider], BaseRetryProvider]