# base_provider.py (3.5 KB)
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Type, Union

from .typing import Messages, CreateResult
  5. class BaseProvider(ABC):
  6. """
  7. Abstract base class for a provider.
  8. Attributes:
  9. url (str): URL of the provider.
  10. working (bool): Indicates if the provider is currently working.
  11. needs_auth (bool): Indicates if the provider needs authentication.
  12. supports_stream (bool): Indicates if the provider supports streaming.
  13. supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
  14. supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
  15. supports_message_history (bool): Indicates if the provider supports message history.
  16. params (str): List parameters for the provider.
  17. """
  18. url: str = None
  19. working: bool = False
  20. needs_auth: bool = False
  21. supports_stream: bool = False
  22. supports_gpt_35_turbo: bool = False
  23. supports_gpt_4: bool = False
  24. supports_message_history: bool = False
  25. params: str
  26. @classmethod
  27. @abstractmethod
  28. def create_completion(
  29. cls,
  30. model: str,
  31. messages: Messages,
  32. stream: bool,
  33. **kwargs
  34. ) -> CreateResult:
  35. """
  36. Create a completion with the given parameters.
  37. Args:
  38. model (str): The model to use.
  39. messages (Messages): The messages to process.
  40. stream (bool): Whether to use streaming.
  41. **kwargs: Additional keyword arguments.
  42. Returns:
  43. CreateResult: The result of the creation process.
  44. """
  45. raise NotImplementedError()
  46. @classmethod
  47. @abstractmethod
  48. async def create_async(
  49. cls,
  50. model: str,
  51. messages: Messages,
  52. **kwargs
  53. ) -> str:
  54. """
  55. Asynchronously create a completion with the given parameters.
  56. Args:
  57. model (str): The model to use.
  58. messages (Messages): The messages to process.
  59. **kwargs: Additional keyword arguments.
  60. Returns:
  61. str: The result of the creation process.
  62. """
  63. raise NotImplementedError()
  64. @classmethod
  65. def get_dict(cls) -> Dict[str, str]:
  66. """
  67. Get a dictionary representation of the provider.
  68. Returns:
  69. Dict[str, str]: A dictionary with provider's details.
  70. """
  71. return {'name': cls.__name__, 'url': cls.url}
  72. class BaseRetryProvider(BaseProvider):
  73. """
  74. Base class for a provider that implements retry logic.
  75. Attributes:
  76. providers (List[Type[BaseProvider]]): List of providers to use for retries.
  77. shuffle (bool): Whether to shuffle the providers list.
  78. exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
  79. last_provider (Type[BaseProvider]): The last provider used.
  80. """
  81. __name__: str = "RetryProvider"
  82. supports_stream: bool = True
  83. def __init__(
  84. self,
  85. providers: List[Type[BaseProvider]],
  86. shuffle: bool = True
  87. ) -> None:
  88. """
  89. Initialize the BaseRetryProvider.
  90. Args:
  91. providers (List[Type[BaseProvider]]): List of providers to use.
  92. shuffle (bool): Whether to shuffle the providers list.
  93. """
  94. self.providers = providers
  95. self.shuffle = shuffle
  96. self.working = True
  97. self.exceptions: Dict[str, Exception] = {}
  98. self.last_provider: Type[BaseProvider] = None
  99. ProviderType = Union[Type[BaseProvider], BaseRetryProvider]