types.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from __future__ import annotations
  2. from abc import ABC, abstractmethod
  3. from typing import Union, Dict, Type
  4. from ..typing import Messages, CreateResult
  5. class BaseProvider(ABC):
  6. """
  7. Abstract base class for a provider.
  8. Attributes:
  9. url (str): URL of the provider.
  10. working (bool): Indicates if the provider is currently working.
  11. needs_auth (bool): Indicates if the provider needs authentication.
  12. supports_stream (bool): Indicates if the provider supports streaming.
  13. supports_message_history (bool): Indicates if the provider supports message history.
  14. supports_system_message (bool): Indicates if the provider supports system messages.
  15. params (str): List parameters for the provider.
  16. """
  17. url: str = None
  18. working: bool = False
  19. needs_auth: bool = False
  20. supports_stream: bool = False
  21. supports_message_history: bool = False
  22. supports_system_message: bool = False
  23. params: str
  24. @classmethod
  25. @abstractmethod
  26. def create_completion(
  27. cls,
  28. model: str,
  29. messages: Messages,
  30. stream: bool,
  31. **kwargs
  32. ) -> CreateResult:
  33. """
  34. Create a completion with the given parameters.
  35. Args:
  36. model (str): The model to use.
  37. messages (Messages): The messages to process.
  38. stream (bool): Whether to use streaming.
  39. **kwargs: Additional keyword arguments.
  40. Returns:
  41. CreateResult: The result of the creation process.
  42. """
  43. raise NotImplementedError()
  44. @classmethod
  45. @abstractmethod
  46. async def create_async(
  47. cls,
  48. model: str,
  49. messages: Messages,
  50. **kwargs
  51. ) -> str:
  52. """
  53. Asynchronously create a completion with the given parameters.
  54. Args:
  55. model (str): The model to use.
  56. messages (Messages): The messages to process.
  57. **kwargs: Additional keyword arguments.
  58. Returns:
  59. str: The result of the creation process.
  60. """
  61. raise NotImplementedError()
  62. @classmethod
  63. def get_dict(cls) -> Dict[str, str]:
  64. """
  65. Get a dictionary representation of the provider.
  66. Returns:
  67. Dict[str, str]: A dictionary with provider's details.
  68. """
  69. return {'name': cls.__name__, 'url': cls.url, 'label': getattr(cls, 'label', None)}
  70. class BaseRetryProvider(BaseProvider):
  71. """
  72. Base class for a provider that implements retry logic.
  73. Attributes:
  74. providers (List[Type[BaseProvider]]): List of providers to use for retries.
  75. shuffle (bool): Whether to shuffle the providers list.
  76. exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
  77. last_provider (Type[BaseProvider]): The last provider used.
  78. """
  79. __name__: str = "RetryProvider"
  80. supports_stream: bool = True
  81. last_provider: Type[BaseProvider] = None
  82. ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
  83. class Streaming():
  84. def __init__(self, data: str) -> None:
  85. self.data = data
  86. def __str__(self) -> str:
  87. return self.data