__init__.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from __future__ import annotations
  2. import random
  3. from ...typing import AsyncResult, Messages
  4. from ...providers.response import ImageResponse
  5. from ...errors import ModelNotSupportedError, MissingAuthError
  6. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  7. from .HuggingChat import HuggingChat
  8. from .HuggingFaceAPI import HuggingFaceAPI
  9. from .HuggingFaceInference import HuggingFaceInference
  10. from .models import model_aliases, vision_models, default_vision_model
  11. from ... import debug
  12. class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
  13. url = "https://huggingface.co"
  14. login_url = "https://huggingface.co/settings/tokens"
  15. working = True
  16. supports_message_history = True
  17. @classmethod
  18. def get_models(cls) -> list[str]:
  19. if not cls.models:
  20. cls.models = HuggingFaceInference.get_models()
  21. cls.image_models = HuggingFaceInference.image_models
  22. return cls.models
  23. model_aliases = model_aliases
  24. vision_models = vision_models
  25. default_vision_model = default_vision_model
  26. @classmethod
  27. async def create_async_generator(
  28. cls,
  29. model: str,
  30. messages: Messages,
  31. **kwargs
  32. ) -> AsyncResult:
  33. if "api_key" not in kwargs and "images" not in kwargs and random.random() >= 0.5:
  34. try:
  35. is_started = False
  36. async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
  37. if isinstance(chunk, (str, ImageResponse)):
  38. is_started = True
  39. yield chunk
  40. if is_started:
  41. return
  42. except Exception as e:
  43. if is_started:
  44. raise e
  45. debug.log(f"Inference failed: {e.__class__.__name__}: {e}")
  46. if not cls.image_models:
  47. cls.get_models()
  48. if model in cls.image_models:
  49. if "api_key" not in kwargs:
  50. async for chunk in HuggingChat.create_async_generator(model, messages, **kwargs):
  51. yield chunk
  52. else:
  53. async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
  54. yield chunk
  55. return
  56. try:
  57. async for chunk in HuggingFaceAPI.create_async_generator(model, messages, **kwargs):
  58. yield chunk
  59. except (ModelNotSupportedError, MissingAuthError):
  60. async for chunk in HuggingFaceInference.create_async_generator(model, messages, **kwargs):
  61. yield chunk