Theb.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from __future__ import annotations
  2. import time
  3. from ...typing import CreateResult, Messages
  4. from ..base_provider import AbstractProvider
  5. from ..helper import format_prompt
  6. from ...webdriver import WebDriver, WebDriverSession, element_send_text
  7. models = {
  8. "theb-ai": "TheB.AI",
  9. "theb-ai-free": "TheB.AI Free",
  10. "gpt-3.5-turbo": "GPT-3.5 Turbo (New)",
  11. "gpt-3.5-turbo-16k": "GPT-3.5-16K",
  12. "gpt-4-turbo": "GPT-4 Turbo",
  13. "gpt-4": "GPT-4",
  14. "gpt-4-32k": "GPT-4 32K",
  15. "claude-2": "Claude 2",
  16. "claude-instant-1": "Claude Instant 1.2",
  17. "palm-2": "PaLM 2",
  18. "palm-2-32k": "PaLM 2 32K",
  19. "palm-2-codey": "Codey",
  20. "palm-2-codey-32k": "Codey 32K",
  21. "vicuna-13b-v1.5": "Vicuna v1.5 13B",
  22. "llama-2-7b-chat": "Llama 2 7B",
  23. "llama-2-13b-chat": "Llama 2 13B",
  24. "llama-2-70b-chat": "Llama 2 70B",
  25. "code-llama-7b": "Code Llama 7B",
  26. "code-llama-13b": "Code Llama 13B",
  27. "code-llama-34b": "Code Llama 34B",
  28. "qwen-7b-chat": "Qwen 7B"
  29. }
  30. class Theb(AbstractProvider):
  31. label = "TheB.AI"
  32. url = "https://beta.theb.ai"
  33. working = True
  34. supports_gpt_35_turbo = True
  35. supports_gpt_4 = True
  36. supports_stream = True
  37. models = models.keys()
  38. @classmethod
  39. def create_completion(
  40. cls,
  41. model: str,
  42. messages: Messages,
  43. stream: bool,
  44. proxy: str = None,
  45. webdriver: WebDriver = None,
  46. virtual_display: bool = True,
  47. **kwargs
  48. ) -> CreateResult:
  49. if model in models:
  50. model = models[model]
  51. prompt = format_prompt(messages)
  52. web_session = WebDriverSession(webdriver, virtual_display=virtual_display, proxy=proxy)
  53. with web_session as driver:
  54. from selenium.webdriver.common.by import By
  55. from selenium.webdriver.support.ui import WebDriverWait
  56. from selenium.webdriver.support import expected_conditions as EC
  57. from selenium.webdriver.common.keys import Keys
  58. # Register fetch hook
  59. script = """
  60. window._fetch = window.fetch;
  61. window.fetch = async (url, options) => {
  62. // Call parent fetch method
  63. const response = await window._fetch(url, options);
  64. if (!url.startsWith("/api/conversation")) {
  65. return result;
  66. }
  67. // Copy response
  68. copy = response.clone();
  69. window._reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
  70. return copy;
  71. }
  72. window._last_message = "";
  73. """
  74. driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
  75. "source": script
  76. })
  77. try:
  78. driver.get(f"{cls.url}/home")
  79. wait = WebDriverWait(driver, 5)
  80. wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
  81. except:
  82. driver = web_session.reopen()
  83. driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
  84. "source": script
  85. })
  86. driver.get(f"{cls.url}/home")
  87. wait = WebDriverWait(driver, 240)
  88. wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
  89. try:
  90. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  91. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  92. except:
  93. pass
  94. if model:
  95. # Load model panel
  96. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "#SelectModel svg")))
  97. time.sleep(0.1)
  98. driver.find_element(By.CSS_SELECTOR, "#SelectModel svg").click()
  99. try:
  100. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  101. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  102. except:
  103. pass
  104. # Select model
  105. selector = f"div.flex-col div.items-center span[title='{model}']"
  106. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, selector)))
  107. span = driver.find_element(By.CSS_SELECTOR, selector)
  108. container = span.find_element(By.XPATH, "//div/../..")
  109. button = container.find_element(By.CSS_SELECTOR, "button.btn-blue.btn-small.border")
  110. button.click()
  111. # Submit prompt
  112. wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
  113. element_send_text(driver.find_element(By.ID, "textareaAutosize"), prompt)
  114. # Read response with reader
  115. script = """
  116. if(window._reader) {
  117. chunk = await window._reader.read();
  118. if (chunk['done']) {
  119. return null;
  120. }
  121. message = '';
  122. chunk['value'].split('\\r\\n').forEach((line, index) => {
  123. if (line.startsWith('data: ')) {
  124. try {
  125. line = JSON.parse(line.substring('data: '.length));
  126. message = line["args"]["content"];
  127. } catch(e) { }
  128. }
  129. });
  130. if (message) {
  131. try {
  132. return message.substring(window._last_message.length);
  133. } finally {
  134. window._last_message = message;
  135. }
  136. }
  137. }
  138. return '';
  139. """
  140. while True:
  141. chunk = driver.execute_script(script)
  142. if chunk:
  143. yield chunk
  144. elif chunk != "":
  145. break
  146. else:
  147. time.sleep(0.1)