Pi.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from __future__ import annotations
  2. import json
  3. from ..typing import AsyncResult, Messages, Cookies
  4. from .base_provider import AsyncGeneratorProvider, format_prompt
  5. from ..requests import StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
  6. class Pi(AsyncGeneratorProvider):
  7. url = "https://pi.ai/talk"
  8. working = True
  9. supports_stream = True
  10. default_model = "pi"
  11. models = [default_model]
  12. _headers: dict = None
  13. _cookies: Cookies = {}
  14. @classmethod
  15. async def create_async_generator(
  16. cls,
  17. model: str,
  18. messages: Messages,
  19. stream: bool,
  20. proxy: str = None,
  21. timeout: int = 180,
  22. conversation_id: str = None,
  23. **kwargs
  24. ) -> AsyncResult:
  25. if cls._headers is None:
  26. args = await get_args_from_nodriver(cls.url, proxy=proxy, timeout=timeout)
  27. cls._cookies = args.get("cookies", {})
  28. cls._headers = args.get("headers")
  29. async with StreamSession(headers=cls._headers, cookies=cls._cookies, proxy=proxy) as session:
  30. if not conversation_id:
  31. conversation_id = await cls.start_conversation(session)
  32. prompt = format_prompt(messages)
  33. else:
  34. prompt = messages[-1]["content"]
  35. answer = cls.ask(session, prompt, conversation_id)
  36. async for line in answer:
  37. if "text" in line:
  38. yield line["text"]
  39. @classmethod
  40. async def start_conversation(cls, session: StreamSession) -> str:
  41. async with session.post('https://pi.ai/api/chat/start', data="{}", headers={
  42. 'accept': 'application/json',
  43. 'x-api-version': '3'
  44. }) as response:
  45. await raise_for_status(response)
  46. return (await response.json())['conversations'][0]['sid']
  47. async def get_chat_history(session: StreamSession, conversation_id: str):
  48. params = {
  49. 'conversation': conversation_id,
  50. }
  51. async with session.get('https://pi.ai/api/chat/history', params=params) as response:
  52. await raise_for_status(response)
  53. return await response.json()
  54. @classmethod
  55. async def ask(cls, session: StreamSession, prompt: str, conversation_id: str):
  56. json_data = {
  57. 'text': prompt,
  58. 'conversation': conversation_id,
  59. 'mode': 'BASE',
  60. }
  61. async with session.post('https://pi.ai/api/chat', json=json_data) as response:
  62. await raise_for_status(response)
  63. cls._cookies = merge_cookies(cls._cookies, response)
  64. async for line in response.iter_lines():
  65. if line.startswith(b'data: {"text":'):
  66. yield json.loads(line.split(b'data: ')[1])
  67. elif line.startswith(b'data: {"title":'):
  68. yield json.loads(line.split(b'data: ')[1])