Pi.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. from __future__ import annotations
  2. import json
  3. from ..typing import AsyncResult, Messages, Cookies
  4. from .base_provider import AsyncGeneratorProvider, format_prompt
  5. from ..requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
  6. class Pi(AsyncGeneratorProvider):
  7. url = "https://pi.ai/talk"
  8. working = True
  9. use_nodriver = True
  10. supports_stream = True
  11. use_nodriver = True
  12. default_model = "pi"
  13. models = [default_model]
  14. _headers: dict = None
  15. _cookies: Cookies = {}
  16. @classmethod
  17. async def create_async_generator(
  18. cls,
  19. model: str,
  20. messages: Messages,
  21. stream: bool,
  22. proxy: str = None,
  23. timeout: int = 180,
  24. conversation_id: str = None,
  25. **kwargs
  26. ) -> AsyncResult:
  27. if cls._headers is None:
  28. args = await get_args_from_nodriver(cls.url, proxy=proxy, timeout=timeout)
  29. cls._cookies = args.get("cookies", {})
  30. cls._headers = args.get("headers")
  31. async with StreamSession(headers=cls._headers, cookies=cls._cookies, proxy=proxy) as session:
  32. if not conversation_id:
  33. conversation_id = await cls.start_conversation(session)
  34. prompt = format_prompt(messages)
  35. else:
  36. prompt = messages[-1]["content"]
  37. answer = cls.ask(session, prompt, conversation_id)
  38. async for line in answer:
  39. if "text" in line:
  40. yield line["text"]
  41. @classmethod
  42. async def start_conversation(cls, session: StreamSession) -> str:
  43. async with session.post('https://pi.ai/api/chat/start', data="{}", headers={
  44. 'accept': 'application/json',
  45. 'x-api-version': '3'
  46. }) as response:
  47. await raise_for_status(response)
  48. return (await response.json())['conversations'][0]['sid']
  49. async def get_chat_history(session: StreamSession, conversation_id: str):
  50. params = {
  51. 'conversation': conversation_id,
  52. }
  53. async with session.get('https://pi.ai/api/chat/history', params=params) as response:
  54. await raise_for_status(response)
  55. return await response.json()
  56. @classmethod
  57. async def ask(cls, session: StreamSession, prompt: str, conversation_id: str):
  58. json_data = {
  59. 'text': prompt,
  60. 'conversation': conversation_id,
  61. 'mode': 'BASE',
  62. }
  63. async with session.post('https://pi.ai/api/chat', json=json_data) as response:
  64. await raise_for_status(response)
  65. cls._cookies = merge_cookies(cls._cookies, response)
  66. async for line in response.iter_lines():
  67. if line.startswith(b'data: {"text":'):
  68. yield json.loads(line.split(b'data: ')[1])
  69. elif line.startswith(b'data: {"title":'):
  70. yield json.loads(line.split(b'data: ')[1])