HuggingFace.py
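"""Async text-generation provider backed by the Hugging Face Inference API.

The `HuggingFace` class reuses the model list of the `HuggingChat` provider,
flattens the conversation into a single Mistral-style prompt (see
`format_prompt` below), posts it to `<api_base>/models/<model>`, and yields
the reply either token by token (streaming) or as one generated text.
"""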

from __future__ import annotations

import json

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...errors import ModelNotFoundError
from ...requests import StreamSession, raise_for_status
from ..HuggingChat import HuggingChat


class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://huggingface.co/chat"
    working = True
    needs_auth = True
    supports_message_history = True
    default_model = HuggingChat.default_model
    models = HuggingChat.models
    model_aliases = HuggingChat.model_aliases

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        proxy: str = None,
        api_base: str = "https://api-inference.huggingface.co",
        api_key: str = None,
        max_new_tokens: int = 1024,
        temperature: float = 0.7,
        **kwargs
    ) -> AsyncResult:
        model = cls.get_model(model)
        # Browser-style headers; origin and referer point at the HuggingChat web UI.
        headers = {
            'accept': '*/*',
            'accept-language': 'en',
            'cache-control': 'no-cache',
            'origin': 'https://huggingface.co',
            'pragma': 'no-cache',
            'priority': 'u=1, i',
            'referer': 'https://huggingface.co/chat/',
            'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"macOS"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
        }
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        # Text-generation parameters; any extra keyword arguments are passed through.
        params = {
            "return_full_text": False,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            **kwargs
        }
        payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream}
        async with StreamSession(
            headers=headers,
            proxy=proxy
        ) as session:
            async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
                if response.status == 404:
                    raise ModelNotFoundError(f"Model is not supported: {model}")
                await raise_for_status(response)
                if stream:
                    first = True
                    async for line in response.iter_lines():
                        if line.startswith(b"data:"):
                            data = json.loads(line[5:])
                            # Skip special tokens such as end-of-sequence markers.
                            if not data["token"]["special"]:
                                chunk = data["token"]["text"]
                                if first:
                                    # Trim leading whitespace from the first chunk only.
                                    first = False
                                    chunk = chunk.lstrip()
                                if chunk:
                                    yield chunk
                else:
                    yield (await response.json())[0]["generated_text"].strip()
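

# For reference: the streaming branch above consumes server-sent-event style
# lines. Judging from the parsing code, each payload line looks roughly like
#     data:{"token": {"text": " Hello", "special": false, ...}, ...}
# and only the "token" fields read above are relied upon here. In the
# non-streaming case the API answers with a JSON array whose first element
# holds the complete "generated_text".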


def format_prompt(messages: Messages) -> str:
    # Flatten the conversation into Mistral-style "<s>[INST] ... [/INST]" turns:
    # each assistant message is paired with the message that precedes it, and the
    # last message (plus any system messages) becomes the final open instruction.
    system_messages = [message["content"] for message in messages if message["role"] == "system"]
    question = " ".join([messages[-1]["content"], *system_messages])
    history = "".join([
        f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>"
        for idx, message in enumerate(messages)
        if message["role"] == "assistant"
    ])
    return f"{history}<s>[INST] {question} [/INST]"
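

# Minimal usage sketch (illustrative): it drives the provider's async generator
# directly. The API token below is a placeholder, and because this module uses
# relative imports it has to run as part of its package (e.g. via `python -m`),
# not as a standalone script.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        messages = [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ]
        # Show the flattened Mistral-style prompt that will be sent.
        print(format_prompt(messages))
        async for chunk in HuggingFace.create_async_generator(
            model=HuggingFace.default_model,
            messages=messages,
            api_key="hf_xxx",  # placeholder: a real Hugging Face token is required
        ):
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())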