# OpenAssistant.py
from __future__ import annotations

import json

from aiohttp import ClientSession

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider
from ..helper import format_prompt, get_cookies
  7. class OpenAssistant(AsyncGeneratorProvider):
  8. url = "https://open-assistant.io/chat"
  9. needs_auth = True
  10. working = False
  11. model = "OA_SFT_Llama_30B_6"
  12. @classmethod
  13. async def create_async_generator(
  14. cls,
  15. model: str,
  16. messages: Messages,
  17. proxy: str = None,
  18. cookies: dict = None,
  19. **kwargs
  20. ) -> AsyncResult:
  21. if not cookies:
  22. cookies = get_cookies("open-assistant.io")
  23. headers = {
  24. 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/111.0.0.0 Safari/537.36',
  25. }
  26. async with ClientSession(
  27. cookies=cookies,
  28. headers=headers
  29. ) as session:
  30. async with session.post("https://open-assistant.io/api/chat", proxy=proxy) as response:
  31. chat_id = (await response.json())["id"]
  32. data = {
  33. "chat_id": chat_id,
  34. "content": f"<s>[INST]\n{format_prompt(messages)}\n[/INST]",
  35. "parent_id": None
  36. }
  37. async with session.post("https://open-assistant.io/api/chat/prompter_message", proxy=proxy, json=data) as response:
  38. parent_id = (await response.json())["id"]
  39. data = {
  40. "chat_id": chat_id,
  41. "parent_id": parent_id,
  42. "model_config_name": model if model else cls.model,
  43. "sampling_parameters":{
  44. "top_k": 50,
  45. "top_p": None,
  46. "typical_p": None,
  47. "temperature": 0.35,
  48. "repetition_penalty": 1.1111111111111112,
  49. "max_new_tokens": 1024,
  50. **kwargs
  51. },
  52. "plugins":[]
  53. }
  54. async with session.post("https://open-assistant.io/api/chat/assistant_message", proxy=proxy, json=data) as response:
  55. data = await response.json()
  56. if "id" in data:
  57. message_id = data["id"]
  58. elif "message" in data:
  59. raise RuntimeError(data["message"])
  60. else:
  61. response.raise_for_status()
  62. params = {
  63. 'chat_id': chat_id,
  64. 'message_id': message_id,
  65. }
  66. async with session.post("https://open-assistant.io/api/chat/events", proxy=proxy, params=params) as response:
  67. start = "data: "
  68. async for line in response.content:
  69. line = line.decode("utf-8")
  70. if line and line.startswith(start):
  71. line = json.loads(line[len(start):])
  72. if line["event_type"] == "token":
  73. yield line["text"]
  74. params = {
  75. 'chat_id': chat_id,
  76. }
  77. async with session.delete("https://open-assistant.io/api/chat", proxy=proxy, params=params) as response:
  78. response.raise_for_status()