types.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from __future__ import annotations
  2. from abc import ABC, abstractmethod
  3. from typing import Union, Dict, Type
  4. from ..typing import Messages, CreateResult
  5. class BaseProvider(ABC):
  6. """
  7. Abstract base class for a provider.
  8. Attributes:
  9. url (str): URL of the provider.
  10. working (bool): Indicates if the provider is currently working.
  11. needs_auth (bool): Indicates if the provider needs authentication.
  12. supports_stream (bool): Indicates if the provider supports streaming.
  13. supports_gpt_35_turbo (bool): Indicates if the provider supports GPT-3.5 Turbo.
  14. supports_gpt_4 (bool): Indicates if the provider supports GPT-4.
  15. supports_message_history (bool): Indicates if the provider supports message history.
  16. params (str): List parameters for the provider.
  17. """
  18. url: str = None
  19. working: bool = False
  20. needs_auth: bool = False
  21. supports_stream: bool = False
  22. supports_gpt_35_turbo: bool = False
  23. supports_gpt_4: bool = False
  24. supports_message_history: bool = False
  25. supports_system_message: bool = False
  26. params: str
  27. @classmethod
  28. @abstractmethod
  29. def create_completion(
  30. cls,
  31. model: str,
  32. messages: Messages,
  33. stream: bool,
  34. **kwargs
  35. ) -> CreateResult:
  36. """
  37. Create a completion with the given parameters.
  38. Args:
  39. model (str): The model to use.
  40. messages (Messages): The messages to process.
  41. stream (bool): Whether to use streaming.
  42. **kwargs: Additional keyword arguments.
  43. Returns:
  44. CreateResult: The result of the creation process.
  45. """
  46. raise NotImplementedError()
  47. @classmethod
  48. @abstractmethod
  49. async def create_async(
  50. cls,
  51. model: str,
  52. messages: Messages,
  53. **kwargs
  54. ) -> str:
  55. """
  56. Asynchronously create a completion with the given parameters.
  57. Args:
  58. model (str): The model to use.
  59. messages (Messages): The messages to process.
  60. **kwargs: Additional keyword arguments.
  61. Returns:
  62. str: The result of the creation process.
  63. """
  64. raise NotImplementedError()
  65. @classmethod
  66. def get_dict(cls) -> Dict[str, str]:
  67. """
  68. Get a dictionary representation of the provider.
  69. Returns:
  70. Dict[str, str]: A dictionary with provider's details.
  71. """
  72. return {'name': cls.__name__, 'url': cls.url}
  73. class BaseRetryProvider(BaseProvider):
  74. """
  75. Base class for a provider that implements retry logic.
  76. Attributes:
  77. providers (List[Type[BaseProvider]]): List of providers to use for retries.
  78. shuffle (bool): Whether to shuffle the providers list.
  79. exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
  80. last_provider (Type[BaseProvider]): The last provider used.
  81. """
  82. __name__: str = "RetryProvider"
  83. supports_stream: bool = True
  84. ProviderType = Union[Type[BaseProvider], BaseRetryProvider]