types.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable, Dict, Type, Union

from ..typing import Messages, CreateResult
  5. class BaseProvider(ABC):
  6. """
  7. Abstract base class for a provider.
  8. Attributes:
  9. url (str): URL of the provider.
  10. working (bool): Indicates if the provider is currently working.
  11. needs_auth (bool): Indicates if the provider needs authentication.
  12. supports_stream (bool): Indicates if the provider supports streaming.
  13. supports_message_history (bool): Indicates if the provider supports message history.
  14. supports_system_message (bool): Indicates if the provider supports system messages.
  15. params (str): List parameters for the provider.
  16. """
  17. url: str = None
  18. working: bool = False
  19. needs_auth: bool = False
  20. supports_stream: bool = False
  21. supports_message_history: bool = False
  22. supports_system_message: bool = False
  23. params: str
  24. @abstractmethod
  25. def get_create_function() -> callable:
  26. """
  27. Get the create function for the provider.
  28. Returns:
  29. callable: The create function.
  30. """
  31. raise NotImplementedError()
  32. @abstractmethod
  33. def get_async_create_function() -> callable:
  34. """
  35. Get the async create function for the provider.
  36. Returns:
  37. callable: The create function.
  38. """
  39. raise NotImplementedError()
  40. @classmethod
  41. def get_dict(cls) -> Dict[str, str]:
  42. """
  43. Get a dictionary representation of the provider.
  44. Returns:
  45. Dict[str, str]: A dictionary with provider's details.
  46. """
  47. return {'name': cls.__name__, 'url': cls.url, 'label': getattr(cls, 'label', None)}
  48. class BaseRetryProvider(BaseProvider):
  49. """
  50. Base class for a provider that implements retry logic.
  51. Attributes:
  52. providers (List[Type[BaseProvider]]): List of providers to use for retries.
  53. shuffle (bool): Whether to shuffle the providers list.
  54. exceptions (Dict[str, Exception]): Dictionary of exceptions encountered.
  55. last_provider (Type[BaseProvider]): The last provider used.
  56. """
  57. __name__: str = "RetryProvider"
  58. supports_stream: bool = True
  59. last_provider: Type[BaseProvider] = None
  60. ProviderType = Union[Type[BaseProvider], BaseRetryProvider]
  61. class Streaming():
  62. def __init__(self, data: str) -> None:
  63. self.data = data
  64. def __str__(self) -> str:
  65. return self.data