HuggingChat.py 8.8 KB


  1. from __future__ import annotations
  2. import json
  3. import re
  4. import os
  5. import requests
  6. import base64
  7. from typing import AsyncIterator
  8. try:
  9. from curl_cffi.requests import Session, CurlMime
  10. has_curl_cffi = True
  11. except ImportError:
  12. has_curl_cffi = False
  13. from ..base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
  14. from ..helper import format_prompt, format_image_prompt, get_last_user_message
  15. from ...typing import AsyncResult, Messages, Cookies, ImagesType
  16. from ...errors import MissingRequirementsError, MissingAuthError, ResponseError
  17. from ...image import to_bytes
  18. from ...requests import get_args_from_nodriver, DEFAULT_HEADERS
  19. from ...requests.raise_for_status import raise_for_status
  20. from ...providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin
  21. from ...cookies import get_cookies
  22. from .models import default_model, fallback_models, image_models, model_aliases
  23. from ... import debug
  class Conversation(JsonConversation):
      """Per-provider conversation state for HuggingChat.

      ``models`` maps a model name to a dict holding the HuggingChat
      ``conversationId`` created for that model and the ``messageId`` of its
      latest message.
      """

      def __init__(self, models: dict) -> None:
          # Keep a reference to the caller's mapping (it is not copied).
          self.models = models
  class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
      """Provider for the HuggingChat web UI at https://huggingface.co/chat.

      Authentication is cookie based: an existing ``hf-chat`` cookie is reused
      when available, otherwise a login is performed through nodriver. Chat
      requests go through a ``curl_cffi`` session built from the resulting
      auth data, and replies are parsed from the newline-delimited JSON stream
      returned by the conversation endpoint.
      """

      url = "https://huggingface.co/chat"
      working = True
      use_nodriver = True
      supports_stream = True
      needs_auth = True
      default_model = default_model
      model_aliases = model_aliases
      image_models = image_models

      # Seconds to wait for the chat landing page when scraping the model list;
      # on timeout the static fallback list is used instead.
      models_request_timeout = 30

      @classmethod
      def get_models(cls):
          """Return the available model ids, scraping them on first use.

          The landing page embeds its model list as a JavaScript object
          literal, which is rewritten into JSON and parsed. On any failure the
          static ``fallback_models`` list is used. The result is cached in
          ``cls.models``; ``cls.text_models`` and ``cls.vision_models`` are only
          updated when scraping succeeds.
          """
          if not cls.models:
              try:
                  # A timeout keeps a stalled connection from hanging discovery.
                  text = requests.get(cls.url, timeout=cls.models_request_timeout).text
                  # Strip per-model ``parameters`` objects (not needed here;
                  # presumably they would not survive the JS-to-JSON rewrite below).
                  text = re.sub(r',parameters:{[^}]+?}', '', text)
                  text = re.search(r'models:(\[.+?\]),oldModels:', text).group(1)
                  # JS ``void 0`` (undefined) has no JSON spelling.
                  text = text.replace('void 0', 'null')

                  def add_quotation_mark(match):
                      return f'{match.group(1)}"{match.group(2)}":'

                  # Quote bare object keys: ``{id:`` -> ``{"id":``.
                  text = re.sub(r'([{,])([A-Za-z0-9_]+?):', add_quotation_mark, text)
                  models = json.loads(text)
                  # Compute everything before touching class state so a malformed
                  # entry cannot leave the model lists half-updated.
                  text_models = [model["id"] for model in models]
                  vision_models = [model["id"] for model in models if model.get("multimodal")]
              except Exception as e:
                  debug.log(f"HuggingChat: Error reading models: {type(e).__name__}: {e}")
                  cls.models = [*fallback_models]
              else:
                  cls.text_models = text_models
                  cls.models = text_models + cls.image_models
                  cls.vision_models = vision_models
          return cls.models

      @classmethod
      async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
          """Yield the authentication data consumed by ``create_authed``.

          If the given cookies (or, when ``cookies`` is None, the cookies read
          from the local browser for huggingface.co) contain ``hf-chat``, they
          are yielded directly as an ``AuthResult``. Otherwise a
          ``RequestLogin`` hint is yielded, followed by an ``AuthResult`` built
          from ``get_args_from_nodriver``, which is given the logout-form
          selector as ``wait_for`` (presumably so it returns once the user is
          logged in).
          """
          if cookies is None:
              cookies = get_cookies("huggingface.co", single_browser=True)
          # Guard against an empty/None cookie jar before the membership test.
          if cookies and "hf-chat" in cookies:
              yield AuthResult(
                  cookies=cookies,
                  impersonate="chrome",
                  headers=DEFAULT_HEADERS
              )
              return
          yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
          yield AuthResult(
              **await get_args_from_nodriver(
                  cls.url,
                  proxy=proxy,
                  wait_for='form[action="/chat/logout"]'
              )
          )

      @classmethod
      async def create_authed(
          cls,
          model: str,
          messages: Messages,
          auth_result: AuthResult,
          prompt: str = None,
          images: ImagesType = None,
          return_conversation: bool = False,
          conversation: Conversation = None,
          web_search: bool = False,
          **kwargs
      ) -> AsyncResult:
          """Send ``messages`` to HuggingChat and stream the reply.

          One HuggingChat conversation is kept per model in
          ``conversation.models``. When none exists yet for ``model``, one is
          created and sent the whole formatted history. If
          ``return_conversation`` is set, the ``Conversation`` object is yielded
          first in that case. An existing conversation is sent only the last
          user message.

          Yields text tokens, ``Reasoning``, ``TitleGeneration`` and
          ``ImageResponse`` objects as they arrive. A ``Sources`` object is
          yielded last if a ``webSearch`` event carried sources.

          Raises:
              MissingRequirementsError: ``curl_cffi`` is not installed.
              RuntimeError: a stream line has no ``type`` field.
              Errors from ``create_conversation``, ``fetch_message_id`` and
              ``raise_for_status`` propagate unchanged.
          """
          if not has_curl_cffi:
              raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
          model = cls.get_model(model)
          session = Session(**auth_result.get_dict())
          data = None
          response = None
          try:
              if conversation is None or not hasattr(conversation, "models"):
                  conversation = Conversation({})
              if model not in conversation.models:
                  conversation_id = cls.create_conversation(session, model)
                  debug.log(f"Conversation created: {json.dumps(conversation_id[8:] + '...')}")
                  message_id = cls.fetch_message_id(session, conversation_id)
                  conversation.models[model] = {"conversationId": conversation_id, "messageId": message_id}
                  if return_conversation:
                      yield conversation
                  # Fresh conversation: send the full history as one prompt.
                  inputs = format_prompt(messages)
              else:
                  conversation_id = conversation.models[model]["conversationId"]
                  conversation.models[model]["messageId"] = cls.fetch_message_id(session, conversation_id)
                  # The server already holds earlier turns; send only the new one.
                  inputs = get_last_user_message(messages)

              settings = {
                  "inputs": inputs,
                  "id": conversation.models[model]["messageId"],
                  "is_retry": False,
                  "is_continue": False,
                  "web_search": web_search,
                  # Image models are sent this tool id; presumably HuggingChat's
                  # image-generation tool - TODO confirm.
                  "tools": ["000000000000000000000001"] if model in cls.image_models else [],
              }
              headers = {
                  'accept': '*/*',
                  'origin': 'https://huggingface.co',
                  'referer': f'https://huggingface.co/chat/conversation/{conversation_id}',
              }
              data = CurlMime()
              data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
              if images is not None:
                  for image, filename in images:
                      data.addpart(
                          "files",
                          filename=f"base64;{filename}",
                          data=base64.b64encode(to_bytes(image))
                      )
              response = session.post(
                  f'https://huggingface.co/chat/conversation/{conversation_id}',
                  headers=headers,
                  multipart=data,
                  stream=True
              )
              raise_for_status(response)

              sources = None
              for line in response.iter_lines():
                  if not line:
                      continue
                  try:
                      line = json.loads(line)
                  except json.JSONDecodeError as e:
                      debug.log(f"Failed to decode JSON: {line}, error: {e}")
                      continue
                  if "type" not in line:
                      raise RuntimeError(f"Response: {line}")
                  elif line["type"] == "stream":
                      # Strip NUL characters from streamed tokens.
                      yield line["token"].replace('\u0000', '')
                  elif line["type"] == "finalAnswer":
                      # Anything after the final answer is ignored.
                      break
                  elif line["type"] == "file":
                      # Generated file: served from the conversation's output endpoint.
                      url = f"https://huggingface.co/chat/conversation/{conversation_id}/output/{line['sha']}"
                      yield ImageResponse(url, format_image_prompt(messages, prompt), options={"cookies": auth_result.cookies})
                  elif line["type"] == "webSearch" and "sources" in line:
                      sources = Sources(line["sources"])
                  elif line["type"] == "title":
                      yield TitleGeneration(line["title"])
                  elif line["type"] == "reasoning":
                      yield Reasoning(line.get("token"), line.get("status"))
              if sources is not None:
                  yield sources
          finally:
              # Release the streaming transfer, the multipart form and the curl
              # handle on every exit path, including an early break, an error,
              # or the consumer abandoning this generator. The stream is closed
              # before the session that owns its handle.
              if response is not None:
                  response.close()
              if data is not None:
                  data.close()
              session.close()

      @classmethod
      def create_conversation(cls, session: Session, model: str):
          """Create a HuggingChat conversation and return its id.

          For image models the conversation is created with
          ``cls.default_model`` instead of the requested model.

          Raises:
              MissingAuthError: the server answered HTTP 401.
              ResponseError: the server answered HTTP 400, or the JSON body
                  contained no ``conversationId``.
          """
          if model in cls.image_models:
              model = cls.default_model
          json_data = {
              'model': model,
          }
          response = session.post('https://huggingface.co/chat/conversation', json=json_data)
          if response.status_code == 401:
              raise MissingAuthError(response.text)
          if response.status_code == 400:
              raise ResponseError(f"{response.text}: Model: {model}")
          raise_for_status(response)
          conversation_id = response.json().get('conversationId')
          if not conversation_id:
              # Fail here with context instead of returning None, which would
              # crash later with an unrelated TypeError.
              raise ResponseError(f"Missing conversationId in response: {response.text}: Model: {model}")
          return conversation_id

      @classmethod
      def fetch_message_id(cls, session: Session, conversation_id: str):
          """Return the id of the latest message in ``conversation_id``.

          The conversation's ``__data.json`` endpoint returns newline-separated
          JSON. The first line that is an object with a ``nodes`` key is used.
          Its ``nodes[1]["data"]`` is a flat array whose entries refer to other
          entries by index, so the message id is resolved by indirection:
          messages list -> last message object -> its ``id`` value.

          Raises:
              MissingAuthError: the last node is an error with status 403.
              ResponseError: the last node is any other error.
              RuntimeError: no ``nodes`` payload was found, or its layout is
                  not the expected one.
          """
          response = session.get(f'https://huggingface.co/chat/conversation/{conversation_id}/__data.json?x-sveltekit-invalidated=11')
          raise_for_status(response)

          json_data = None
          for line in response.text.split('\n'):
              if not line.strip():
                  continue
              try:
                  parsed = json.loads(line)
              except json.JSONDecodeError:
                  continue
              if isinstance(parsed, dict) and "nodes" in parsed:
                  json_data = parsed
                  break
          if not json_data:
              raise RuntimeError("Failed to parse response data")

          try:
              last_node = json_data["nodes"][-1]
              is_error = last_node["type"] == "error"
          except (KeyError, IndexError, TypeError) as e:
              raise RuntimeError(f"Failed to extract message ID: {str(e)}") from e
          if is_error:
              # Read optional fields defensively so a sparse error node still
              # surfaces as an auth/response error rather than a parse failure.
              error = last_node.get("error")
              message = error.get("message") if isinstance(error, dict) else None
              if last_node.get("status") == 403:
                  raise MissingAuthError(message or json.dumps(last_node))
              raise ResponseError(json.dumps(last_node))

          try:
              data = json_data["nodes"][1]["data"]
              keys = data[data[0]["messages"]]
              message_keys = data[keys[-1]]
              return data[message_keys["id"]]
          except (KeyError, IndexError, TypeError) as e:
              raise RuntimeError(f"Failed to extract message ID: {str(e)}") from e
  200. return data[message_keys["id"]]
  201. except (KeyError, IndexError, TypeError) as e:
  202. raise RuntimeError(f"Failed to extract message ID: {str(e)}")