base_provider.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type, Union

from .typing import CreateResult, Messages
  4. class BaseProvider(ABC):
  5. """
  6. Abstract base class for a provider.
  7. Attributes:
  8. url (str): URL of the provider.
  9. working (bool): Indicates if the provider is currently working.
  10. needs_auth (bool): Indicates if the provider needs authentication.
  11. supports_stream (bool): Indicates if the provider supports streaming.
  12. supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
  13. supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
  14. supports_message_history (bool): Indicates if the provider supports message history.
  15. params (str): List parameters for the provider.
  16. """
  17. url: str = None
  18. working: bool = False
  19. needs_auth: bool = False
  20. supports_stream: bool = False
  21. supports_gpt_35_turbo: bool = False
  22. supports_gpt_4: bool = False
  23. supports_message_history: bool = False
  24. params: str
  25. @classmethod
  26. @abstractmethod
  27. def create_completion(
  28. cls,
  29. model: str,
  30. messages: Messages,
  31. stream: bool,
  32. **kwargs
  33. ) -> CreateResult:
  34. """
  35. Create a completion with the given parameters.
  36. Args:
  37. model (str): The model to use.
  38. messages (Messages): The messages to process.
  39. stream (bool): Whether to use streaming.
  40. **kwargs: Additional keyword arguments.
  41. Returns:
  42. CreateResult: The result of the creation process.
  43. """
  44. raise NotImplementedError()
  45. @classmethod
  46. @abstractmethod
  47. async def create_async(
  48. cls,
  49. model: str,
  50. messages: Messages,
  51. **kwargs
  52. ) -> str:
  53. """
  54. Asynchronously create a completion with the given parameters.
  55. Args:
  56. model (str): The model to use.
  57. messages (Messages): The messages to process.
  58. **kwargs: Additional keyword arguments.
  59. Returns:
  60. str: The result of the creation process.
  61. """
  62. raise NotImplementedError()
  63. @classmethod
  64. def get_dict(cls) -> Dict[str, str]:
  65. """
  66. Get a dictionary representation of the provider.
  67. Returns:
  68. Dict[str, str]: A dictionary with provider's details.
  69. """
  70. return {'name': cls.__name__, 'url': cls.url}
  71. class BaseRetryProvider(BaseProvider):
  72. """
  73. Base class for a provider that implements retry logic.
  74. Attributes:
  75. providers (List[Type[BaseProvider]]): List of providers to use for retries.
  76. shuffle (bool): Whether to shuffle the providers list.
  77. exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
  78. last_provider (Type[BaseProvider]): The last provider used.
  79. """
  80. __name__: str = "RetryProvider"
  81. supports_stream: bool = True
  82. def __init__(
  83. self,
  84. providers: List[Type[BaseProvider]],
  85. shuffle: bool = True
  86. ) -> None:
  87. """
  88. Initialize the BaseRetryProvider.
  89. Args:
  90. providers (List[Type[BaseProvider]]): List of providers to use.
  91. shuffle (bool): Whether to shuffle the providers list.
  92. """
  93. self.providers = providers
  94. self.shuffle = shuffle
  95. self.working = True
  96. self.exceptions: Dict[str, Exception] = {}
  97. self.last_provider: Type[BaseProvider] = None
# Anything that can be used as a provider: either a BaseProvider subclass (its
# create_completion/create_async are classmethods, so the class itself is used)
# or a BaseRetryProvider *instance*, which is constructed from a list of
# provider classes.
ProviderType = Union[Type[BaseProvider], BaseRetryProvider]