HuggingChat.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from __future__ import annotations
  2. import json, uuid
  3. from aiohttp import ClientSession
  4. from ...typing import AsyncResult, Messages
  5. from ..base_provider import AsyncGeneratorProvider
  6. from ..helper import format_prompt, get_cookies
  7. class HuggingChat(AsyncGeneratorProvider):
  8. url = "https://huggingface.co/chat"
  9. working = True
  10. model = "meta-llama/Llama-2-70b-chat-hf"
  11. @classmethod
  12. async def create_async_generator(
  13. cls,
  14. model: str,
  15. messages: Messages,
  16. stream: bool = True,
  17. proxy: str = None,
  18. web_search: bool = False,
  19. cookies: dict = None,
  20. **kwargs
  21. ) -> AsyncResult:
  22. model = model if model else cls.model
  23. if not cookies:
  24. cookies = get_cookies(".huggingface.co")
  25. headers = {
  26. 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
  27. }
  28. async with ClientSession(
  29. cookies=cookies,
  30. headers=headers
  31. ) as session:
  32. async with session.post(f"{cls.url}/conversation", json={"model": model}, proxy=proxy) as response:
  33. conversation_id = (await response.json())["conversationId"]
  34. send = {
  35. "id": str(uuid.uuid4()),
  36. "inputs": format_prompt(messages),
  37. "is_retry": False,
  38. "response_id": str(uuid.uuid4()),
  39. "web_search": web_search
  40. }
  41. async with session.post(f"{cls.url}/conversation/{conversation_id}", json=send, proxy=proxy) as response:
  42. async for line in response.content:
  43. line = json.loads(line[:-1])
  44. if "type" not in line:
  45. raise RuntimeError(f"Response: {line}")
  46. elif line["type"] == "stream":
  47. yield line["token"]
  48. elif line["type"] == "finalAnswer":
  49. break
  50. async with session.delete(f"{cls.url}/conversation/{conversation_id}", proxy=proxy) as response:
  51. response.raise_for_status()