Pi.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from __future__ import annotations
  2. import json
  3. from ..typing import CreateResult, Messages
  4. from .base_provider import AbstractProvider, format_prompt
  5. from ..requests import Session, get_session_from_browser
  6. class Pi(AbstractProvider):
  7. url = "https://pi.ai/talk"
  8. working = True
  9. supports_stream = True
  10. @classmethod
  11. def create_completion(
  12. cls,
  13. model: str,
  14. messages: Messages,
  15. stream: bool,
  16. session: Session = None,
  17. proxy: str = None,
  18. timeout: int = 180,
  19. conversation_id: str = None,
  20. **kwargs
  21. ) -> CreateResult:
  22. if not session:
  23. session = get_session_from_browser(url=cls.url, proxy=proxy, timeout=timeout)
  24. if not conversation_id:
  25. conversation_id = cls.start_conversation(session)
  26. prompt = format_prompt(messages)
  27. else:
  28. prompt = messages[-1]["content"]
  29. answer = cls.ask(session, prompt, conversation_id)
  30. for line in answer:
  31. if "text" in line:
  32. yield line["text"]
  33. @classmethod
  34. def start_conversation(cls, session: Session) -> str:
  35. response = session.post('https://pi.ai/api/chat/start', data="{}", headers={
  36. 'accept': 'application/json',
  37. 'x-api-version': '3'
  38. })
  39. if 'Just a moment' in response.text:
  40. raise RuntimeError('Error: Cloudflare detected')
  41. return response.json()['conversations'][0]['sid']
  42. def get_chat_history(session: Session, conversation_id: str):
  43. params = {
  44. 'conversation': conversation_id,
  45. }
  46. response = session.get('https://pi.ai/api/chat/history', params=params)
  47. if 'Just a moment' in response.text:
  48. raise RuntimeError('Error: Cloudflare detected')
  49. return response.json()
  50. def ask(session: Session, prompt: str, conversation_id: str):
  51. json_data = {
  52. 'text': prompt,
  53. 'conversation': conversation_id,
  54. 'mode': 'BASE',
  55. }
  56. response = session.post('https://pi.ai/api/chat', json=json_data, stream=True)
  57. for line in response.iter_lines():
  58. if b'Just a moment' in line:
  59. raise RuntimeError('Error: Cloudflare detected')
  60. if line.startswith(b'data: {"text":'):
  61. yield json.loads(line.split(b'data: ')[1])
  62. elif line.startswith(b'data: {"title":'):
  63. yield json.loads(line.split(b'data: ')[1])