PerplexityAi.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. from __future__ import annotations
  2. import time
  3. from selenium.webdriver.common.by import By
  4. from selenium.webdriver.support.ui import WebDriverWait
  5. from selenium.webdriver.support import expected_conditions as EC
  6. from selenium.webdriver.common.keys import Keys
  7. from ..typing import CreateResult, Messages
  8. from .base_provider import AbstractProvider
  9. from .helper import format_prompt
  10. from ..webdriver import WebDriver, WebDriverSession
  11. class PerplexityAi(AbstractProvider):
  12. url = "https://www.perplexity.ai"
  13. working = True
  14. supports_gpt_35_turbo = True
  15. supports_stream = True
  16. @classmethod
  17. def create_completion(
  18. cls,
  19. model: str,
  20. messages: Messages,
  21. stream: bool,
  22. proxy: str = None,
  23. timeout: int = 120,
  24. webdriver: WebDriver = None,
  25. virtual_display: bool = True,
  26. copilot: bool = False,
  27. **kwargs
  28. ) -> CreateResult:
  29. with WebDriverSession(webdriver, "", virtual_display=virtual_display, proxy=proxy) as driver:
  30. prompt = format_prompt(messages)
  31. driver.get(f"{cls.url}/")
  32. wait = WebDriverWait(driver, timeout)
  33. # Is page loaded?
  34. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']")))
  35. # Register WebSocket hook
  36. script = """
  37. window._message = window._last_message = "";
  38. window._message_finished = false;
  39. const _socket_send = WebSocket.prototype.send;
  40. WebSocket.prototype.send = function(...args) {
  41. if (!window.socket_onmessage) {
  42. window._socket_onmessage = this;
  43. this.addEventListener("message", (event) => {
  44. if (event.data.startsWith("42")) {
  45. let data = JSON.parse(event.data.substring(2));
  46. if (data[0] =="query_progress" || data[0] == "query_answered") {
  47. let content = JSON.parse(data[1]["text"]);
  48. if (data[1]["mode"] == "copilot") {
  49. content = content[content.length-1]["content"]["answer"];
  50. content = JSON.parse(content);
  51. }
  52. window._message = content["answer"];
  53. if (!window._message_finished) {
  54. window._message_finished = data[0] == "query_answered";
  55. }
  56. }
  57. }
  58. });
  59. }
  60. return _socket_send.call(this, ...args);
  61. };
  62. """
  63. driver.execute_script(script)
  64. if copilot:
  65. try:
  66. # Check for account
  67. driver.find_element(By.CSS_SELECTOR, "img[alt='User avatar']")
  68. # Enable copilot
  69. driver.find_element(By.CSS_SELECTOR, "button[data-testid='copilot-toggle']").click()
  70. except:
  71. raise RuntimeError("You need a account for copilot")
  72. # Submit prompt
  73. driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(prompt)
  74. driver.find_element(By.CSS_SELECTOR, "textarea[placeholder='Ask anything...']").send_keys(Keys.ENTER)
  75. # Stream response
  76. script = """
  77. if(window._message && window._message != window._last_message) {
  78. try {
  79. return window._message.substring(window._last_message.length);
  80. } finally {
  81. window._last_message = window._message;
  82. }
  83. } else if(window._message_finished) {
  84. return null;
  85. } else {
  86. return '';
  87. }
  88. """
  89. while True:
  90. chunk = driver.execute_script(script)
  91. if chunk:
  92. yield chunk
  93. elif chunk != "":
  94. break
  95. else:
  96. time.sleep(0.1)