Poe.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from __future__ import annotations
  2. import time
  3. from ...typing import CreateResult, Messages
  4. from ..base_provider import AbstractProvider
  5. from ..helper import format_prompt
  6. from ...webdriver import WebDriver, WebDriverSession, element_send_text
  7. models = {
  8. "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"},
  9. "meta-llama/Llama-2-13b-chat-hf": {"name": "Llama-2-13b"},
  10. "meta-llama/Llama-2-70b-chat-hf": {"name": "Llama-2-70b"},
  11. "codellama/CodeLlama-7b-Instruct-hf": {"name": "Code-Llama-7b"},
  12. "codellama/CodeLlama-13b-Instruct-hf": {"name": "Code-Llama-13b"},
  13. "codellama/CodeLlama-34b-Instruct-hf": {"name": "Code-Llama-34b"},
  14. "gpt-3.5-turbo": {"name": "GPT-3.5-Turbo"},
  15. "gpt-3.5-turbo-instruct": {"name": "GPT-3.5-Turbo-Instruct"},
  16. "gpt-4": {"name": "GPT-4"},
  17. "palm": {"name": "Google-PaLM"},
  18. }
  19. class Poe(AbstractProvider):
  20. url = "https://poe.com"
  21. working = True
  22. needs_auth = True
  23. supports_gpt_35_turbo = True
  24. supports_stream = True
  25. models = models.keys()
  26. @classmethod
  27. def create_completion(
  28. cls,
  29. model: str,
  30. messages: Messages,
  31. stream: bool,
  32. proxy: str = None,
  33. webdriver: WebDriver = None,
  34. user_data_dir: str = None,
  35. headless: bool = True,
  36. **kwargs
  37. ) -> CreateResult:
  38. if not model:
  39. model = "gpt-3.5-turbo"
  40. elif model not in models:
  41. raise ValueError(f"Model are not supported: {model}")
  42. prompt = format_prompt(messages)
  43. session = WebDriverSession(webdriver, user_data_dir, headless, proxy=proxy)
  44. with session as driver:
  45. from selenium.webdriver.common.by import By
  46. from selenium.webdriver.support.ui import WebDriverWait
  47. from selenium.webdriver.support import expected_conditions as EC
  48. driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
  49. "source": """
  50. window._message = window._last_message = "";
  51. window._message_finished = false;
  52. class ProxiedWebSocket extends WebSocket {
  53. constructor(url, options) {
  54. super(url, options);
  55. this.addEventListener("message", (e) => {
  56. const data = JSON.parse(JSON.parse(e.data)["messages"][0])["payload"]["data"];
  57. if ("messageAdded" in data) {
  58. if (data["messageAdded"]["author"] != "human") {
  59. window._message = data["messageAdded"]["text"];
  60. if (data["messageAdded"]["state"] == "complete") {
  61. window._message_finished = true;
  62. }
  63. }
  64. }
  65. });
  66. }
  67. }
  68. window.WebSocket = ProxiedWebSocket;
  69. """
  70. })
  71. try:
  72. driver.get(f"{cls.url}/{models[model]['name']}")
  73. wait = WebDriverWait(driver, 10 if headless else 240)
  74. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
  75. except:
  76. # Reopen browser for login
  77. if not webdriver:
  78. driver = session.reopen()
  79. driver.get(f"{cls.url}/{models[model]['name']}")
  80. wait = WebDriverWait(driver, 240)
  81. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
  82. else:
  83. raise RuntimeError("Prompt textarea not found. You may not be logged in.")
  84. element_send_text(driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']"), prompt)
  85. driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click()
  86. script = """
  87. if(window._message && window._message != window._last_message) {
  88. try {
  89. return window._message.substring(window._last_message.length);
  90. } finally {
  91. window._last_message = window._message;
  92. }
  93. } else if(window._message_finished) {
  94. return null;
  95. } else {
  96. return '';
  97. }
  98. """
  99. while True:
  100. chunk = driver.execute_script(script)
  101. if chunk:
  102. yield chunk
  103. elif chunk != "":
  104. break
  105. else:
  106. time.sleep(0.1)