from __future__ import annotations

import asyncio
import json
import random
import re
from pathlib import Path
from typing import AsyncIterator

from aiohttp import ClientSession

from ..typing import AsyncResult, Messages
from ..image import ImageResponse
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..cookies import get_cookies_dir
from .. import debug
  14. class Blackbox2(AsyncGeneratorProvider, ProviderModelMixin):
  15. url = "https://www.blackbox.ai"
  16. api_endpoints = {
  17. "llama-3.1-70b": "https://www.blackbox.ai/api/improve-prompt",
  18. "flux": "https://www.blackbox.ai/api/image-generator"
  19. }
  20. working = True
  21. supports_system_message = True
  22. supports_message_history = True
  23. supports_stream = False
  24. default_model = 'llama-3.1-70b'
  25. chat_models = ['llama-3.1-70b']
  26. image_models = ['flux']
  27. models = [*chat_models, *image_models]
  28. @classmethod
  29. def _get_cache_file(cls) -> Path:
  30. """Returns the path to the cache file."""
  31. dir = Path(get_cookies_dir())
  32. dir.mkdir(exist_ok=True)
  33. return dir / 'blackbox2.json'
  34. @classmethod
  35. def _load_cached_license(cls) -> str | None:
  36. """Loads the license key from the cache."""
  37. cache_file = cls._get_cache_file()
  38. if cache_file.exists():
  39. try:
  40. with open(cache_file, 'r') as f:
  41. data = json.load(f)
  42. return data.get('license_key')
  43. except Exception as e:
  44. debug.log(f"Error reading cache file: {e}")
  45. return None
  46. @classmethod
  47. def _save_cached_license(cls, license_key: str):
  48. """Saves the license key to the cache."""
  49. cache_file = cls._get_cache_file()
  50. try:
  51. with open(cache_file, 'w') as f:
  52. json.dump({'license_key': license_key}, f)
  53. except Exception as e:
  54. debug.log(f"Error writing to cache file: {e}")
  55. @classmethod
  56. async def _get_license_key(cls, session: ClientSession) -> str:
  57. cached_license = cls._load_cached_license()
  58. if cached_license:
  59. return cached_license
  60. try:
  61. async with session.get(cls.url) as response:
  62. html = await response.text()
  63. js_files = re.findall(r'static/chunks/\d{4}-[a-fA-F0-9]+\.js', html)
  64. license_format = r'["\'](\d{6}-\d{6}-\d{6}-\d{6}-\d{6})["\']'
  65. def is_valid_context(text_around):
  66. return any(char + '=' in text_around for char in 'abcdefghijklmnopqrstuvwxyz')
  67. for js_file in js_files:
  68. js_url = f"{cls.url}/_next/{js_file}"
  69. async with session.get(js_url) as js_response:
  70. js_content = await js_response.text()
  71. for match in re.finditer(license_format, js_content):
  72. start = max(0, match.start() - 10)
  73. end = min(len(js_content), match.end() + 10)
  74. context = js_content[start:end]
  75. if is_valid_context(context):
  76. license_key = match.group(1)
  77. cls._save_cached_license(license_key)
  78. return license_key
  79. raise ValueError("License key not found")
  80. except Exception as e:
  81. debug.log(f"Error getting license key: {str(e)}")
  82. raise
  83. @classmethod
  84. async def create_async_generator(
  85. cls,
  86. model: str,
  87. messages: Messages,
  88. prompt: str = None,
  89. proxy: str = None,
  90. max_retries: int = 3,
  91. delay: int = 1,
  92. max_tokens: int = None,
  93. **kwargs
  94. ) -> AsyncResult:
  95. if not model:
  96. model = cls.default_model
  97. if model in cls.chat_models:
  98. async for result in cls._generate_text(model, messages, proxy, max_retries, delay, max_tokens):
  99. yield result
  100. elif model in cls.image_models:
  101. prompt = messages[-1]["content"]
  102. async for result in cls._generate_image(model, prompt, proxy):
  103. yield result
  104. else:
  105. raise ValueError(f"Unsupported model: {model}")
  106. @classmethod
  107. async def _generate_text(
  108. cls,
  109. model: str,
  110. messages: Messages,
  111. proxy: str = None,
  112. max_retries: int = 3,
  113. delay: int = 1,
  114. max_tokens: int = None,
  115. ) -> AsyncIterator[str]:
  116. headers = cls._get_headers()
  117. async with ClientSession(headers=headers) as session:
  118. license_key = await cls._get_license_key(session)
  119. api_endpoint = cls.api_endpoints[model]
  120. data = {
  121. "messages": messages,
  122. "max_tokens": max_tokens,
  123. "validated": license_key
  124. }
  125. for attempt in range(max_retries):
  126. try:
  127. async with session.post(api_endpoint, json=data, proxy=proxy) as response:
  128. response.raise_for_status()
  129. response_data = await response.json()
  130. if 'prompt' in response_data:
  131. yield response_data['prompt']
  132. return
  133. else:
  134. raise KeyError("'prompt' key not found in the response")
  135. except Exception as e:
  136. if attempt == max_retries - 1:
  137. raise RuntimeError(f"Error after {max_retries} attempts: {str(e)}")
  138. else:
  139. wait_time = delay * (2 ** attempt) + random.uniform(0, 1)
  140. debug.log(f"Attempt {attempt + 1} failed. Retrying in {wait_time:.2f} seconds...")
  141. await asyncio.sleep(wait_time)
  142. @classmethod
  143. async def _generate_image(
  144. cls,
  145. model: str,
  146. prompt: str,
  147. proxy: str = None
  148. ) -> AsyncIterator[ImageResponse]:
  149. headers = cls._get_headers()
  150. api_endpoint = cls.api_endpoints[model]
  151. async with ClientSession(headers=headers) as session:
  152. data = {
  153. "query": prompt
  154. }
  155. async with session.post(api_endpoint, headers=headers, json=data, proxy=proxy) as response:
  156. response.raise_for_status()
  157. response_data = await response.json()
  158. if 'markdown' in response_data:
  159. image_url = response_data['markdown'].split('(')[1].split(')')[0]
  160. yield ImageResponse(images=image_url, alt=prompt)
  161. @staticmethod
  162. def _get_headers() -> dict:
  163. return {
  164. 'accept': '*/*',
  165. 'accept-language': 'en-US,en;q=0.9',
  166. 'content-type': 'text/plain;charset=UTF-8',
  167. 'origin': 'https://www.blackbox.ai',
  168. 'referer': 'https://www.blackbox.ai',
  169. 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36'
  170. }