from __future__ import annotations

import json

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...errors import ModelNotFoundError
from ...requests import StreamSession, raise_for_status
from ..HuggingChat import HuggingChat
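
# Text-generation provider backed by the Hugging Face Inference API.
# The model list and aliases are shared with the HuggingChat provider,
# and requests are authenticated with a Hugging Face access token
# (hence needs_auth).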
class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://huggingface.co/chat"
    working = True
    needs_auth = True
    supports_message_history = True
    default_model = HuggingChat.default_model
    models = HuggingChat.models
    model_aliases = HuggingChat.model_aliases
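
    # Request a completion from the Inference API. With stream=True,
    # tokens are yielded as they arrive; otherwise the full generated
    # text is yielded once.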
    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        proxy: str = None,
        api_base: str = "https://api-inference.huggingface.co",
        api_key: str = None,
        max_new_tokens: int = 1024,
        temperature: float = 0.7,
        **kwargs
    ) -> AsyncResult:
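        # Resolve model aliases to the canonical model id.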
        model = cls.get_model(model)
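        # Browser-like headers matching a request from huggingface.co/chat.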
        headers = {
            'accept': '*/*',
            'accept-language': 'en',
            'cache-control': 'no-cache',
            'origin': 'https://huggingface.co',
            'pragma': 'no-cache',
            'priority': 'u=1, i',
            'referer': 'https://huggingface.co/chat/',
            'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"macOS"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
        }
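        # The Inference API accepts a Hugging Face access token as a bearer token.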
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        params = {
            "return_full_text": False,
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            **kwargs
        }
        payload = {"inputs": format_prompt(messages), "parameters": params, "stream": stream}
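        # Open a streaming session and POST to the model's endpoint;
        # a 404 means the requested model id is not served by the API.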
        async with StreamSession(
            headers=headers,
            proxy=proxy
        ) as session:
            async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
                if response.status == 404:
                    raise ModelNotFoundError(f"Model is not supported: {model}")
                await raise_for_status(response)
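                # In streaming mode the API sends server-sent events: each
                # "data:" line carries one generated token. Special tokens
                # (e.g. end-of-sequence markers) are skipped, and leading
                # whitespace is stripped from the first visible chunk.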
                if stream:
                    first = True
                    async for line in response.iter_lines():
                        if line.startswith(b"data:"):
                            data = json.loads(line[5:])
                            if not data["token"]["special"]:
                                chunk = data["token"]["text"]
                                if first:
                                    first = False
                                    chunk = chunk.lstrip()
                                if chunk:
                                    yield chunk
                else:
                    yield (await response.json())[0]["generated_text"].strip()
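
# Build a Llama/Mistral-style instruction prompt: each completed assistant
# turn becomes a "<s>[INST] ... [/INST] ... </s>" pair, and the last user
# message (with any system messages appended) forms the final open
# instruction for the model to complete.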
def format_prompt(messages: Messages) -> str:
    system_messages = [message["content"] for message in messages if message["role"] == "system"]
    question = " ".join([messages[-1]["content"], *system_messages])
    history = "".join([
        f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>"
        for idx, message in enumerate(messages)
        if message["role"] == "assistant"
    ])
    return f"{history}<s>[INST] {question} [/INST]"