# H2o.py
from __future__ import annotations

import json
import uuid

from aiohttp import ClientSession

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, format_prompt
  7. class H2o(AsyncGeneratorProvider):
  8. url = "https://gpt-gm.h2o.ai"
  9. model = "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1"
  10. @classmethod
  11. async def create_async_generator(
  12. cls,
  13. model: str,
  14. messages: Messages,
  15. proxy: str = None,
  16. **kwargs
  17. ) -> AsyncResult:
  18. model = model if model else cls.model
  19. headers = {"Referer": f"{cls.url}/"}
  20. async with ClientSession(
  21. headers=headers
  22. ) as session:
  23. data = {
  24. "ethicsModalAccepted": "true",
  25. "shareConversationsWithModelAuthors": "true",
  26. "ethicsModalAcceptedAt": "",
  27. "activeModel": model,
  28. "searchEnabled": "true",
  29. }
  30. async with session.post(
  31. f"{cls.url}/settings",
  32. proxy=proxy,
  33. data=data
  34. ) as response:
  35. response.raise_for_status()
  36. async with session.post(
  37. f"{cls.url}/conversation",
  38. proxy=proxy,
  39. json={"model": model},
  40. ) as response:
  41. response.raise_for_status()
  42. conversationId = (await response.json())["conversationId"]
  43. data = {
  44. "inputs": format_prompt(messages),
  45. "parameters": {
  46. "temperature": 0.4,
  47. "truncate": 2048,
  48. "max_new_tokens": 1024,
  49. "do_sample": True,
  50. "repetition_penalty": 1.2,
  51. "return_full_text": False,
  52. **kwargs
  53. },
  54. "stream": True,
  55. "options": {
  56. "id": str(uuid.uuid4()),
  57. "response_id": str(uuid.uuid4()),
  58. "is_retry": False,
  59. "use_cache": False,
  60. "web_search_id": "",
  61. },
  62. }
  63. async with session.post(
  64. f"{cls.url}/conversation/{conversationId}",
  65. proxy=proxy,
  66. json=data
  67. ) as response:
  68. start = "data:"
  69. async for line in response.content:
  70. line = line.decode("utf-8")
  71. if line and line.startswith(start):
  72. line = json.loads(line[len(start):-1])
  73. if not line["token"]["special"]:
  74. yield line["token"]["text"]
  75. async with session.delete(
  76. f"{cls.url}/conversation/{conversationId}",
  77. proxy=proxy,
  78. ) as response:
  79. response.raise_for_status()