DDG.py

from __future__ import annotations

import asyncio
import json
import random
import time

from aiohttp import ClientSession, ClientTimeout

from ..typing import AsyncResult, Messages, Cookies
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt
from ..providers.response import FinishReason, JsonConversation
from ..errors import ModelNotSupportedError, ResponseStatusError, RateLimitError, TimeoutError, ConversationLimitError


class DuckDuckGoSearchException(Exception):
    """Base exception class for duckduckgo_search."""


class Conversation(JsonConversation):
    vqd: str = None
    message_history: Messages = []
    cookies: dict = {}

    def __init__(self, model: str):
        self.model = model
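
# Conversation carries the per-chat state (VQD token, message history, and
# cookies) that lets a later call resume an existing chat instead of starting over.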


class DDG(AsyncGeneratorProvider, ProviderModelMixin):
    label = "DuckDuckGo AI Chat"
    url = "https://duckduckgo.com/aichat"
    api_endpoint = "https://duckduckgo.com/duckchat/v1/chat"
    status_url = "https://duckduckgo.com/duckchat/v1/status"

    working = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True

    default_model = "gpt-4o-mini"
    models = [
        default_model,
        "claude-3-haiku-20240307",
        "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        "mistralai/Mixtral-8x7B-Instruct-v0.1",
    ]
    model_aliases = {
        "gpt-4": "gpt-4o-mini",
        "claude-3-haiku": "claude-3-haiku-20240307",
        "llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    }

    last_request_time = 0

    @classmethod
    def validate_model(cls, model: str) -> str:
        """Validates and returns the correct model name."""
        if not model:
            return cls.default_model
        if model in cls.model_aliases:
            model = cls.model_aliases[model]
        if model not in cls.models:
            raise ModelNotSupportedError(f"Model {model} not supported. Available models: {cls.models}")
        return model
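
    # For example, validate_model("claude-3-haiku") resolves to the canonical
    # "claude-3-haiku-20240307", and an empty model name falls back to
    # default_model ("gpt-4o-mini").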

    @classmethod
    async def sleep(cls):
        """Implements rate limiting between requests."""
        now = time.time()
        if cls.last_request_time > 0:
            delay = max(0.0, 0.75 - (now - cls.last_request_time))
            if delay > 0:
                await asyncio.sleep(delay)
        cls.last_request_time = now
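
    # Because last_request_time is stored on the class, the 0.75 s minimum gap
    # between requests is shared across all conversations using this provider,
    # not enforced per-session.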

    @classmethod
    async def fetch_vqd(cls, session: ClientSession, max_retries: int = 3) -> str:
        """Fetches the required VQD token for the chat session with retries."""
        headers = {
            "accept": "text/event-stream",
            "content-type": "application/json",
            "x-vqd-accept": "1",
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
        }
        for attempt in range(max_retries):
            try:
                await cls.sleep()
                async with session.get(cls.status_url, headers=headers) as response:
                    await raise_for_status(response)
                    vqd = response.headers.get("x-vqd-4", "")
                    if vqd:
                        return vqd
                    response_text = await response.text()
                    raise RuntimeError(f"Failed to fetch VQD token: {response.status} {response_text}")
            except ResponseStatusError as e:
                if attempt < max_retries - 1:
                    # Back off with jitter, waiting longer on each retry.
                    wait_time = random.uniform(1, 3) * (attempt + 1)
                    await asyncio.sleep(wait_time)
                else:
                    raise RuntimeError(f"Failed to fetch VQD token after {max_retries} attempts: {e}")
        raise RuntimeError("Failed to fetch VQD token: Maximum retries exceeded")
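
    # Token flow: the "x-vqd-accept: 1" header asks the status endpoint to
    # issue a fresh VQD token; it is returned in the "x-vqd-4" response header
    # and must accompany every subsequent chat request.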

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        timeout: int = 30,
        cookies: Cookies = None,
        conversation: Conversation = None,
        return_conversation: bool = False,
        **kwargs
    ) -> AsyncResult:
        model = cls.validate_model(model)
        if cookies is None and conversation is not None:
            cookies = conversation.cookies
        try:
            async with ClientSession(timeout=ClientTimeout(total=timeout), cookies=cookies) as session:
                if conversation is None:
                    # First turn: fetch a fresh VQD token and start the history.
                    conversation = Conversation(model)
                    conversation.vqd = await cls.fetch_vqd(session)
                    conversation.message_history = [{"role": "user", "content": format_prompt(messages)}]
                else:
                    # Follow-up turn: append only the newest message.
                    conversation.message_history.append(messages[-1])
                headers = {
                    "accept": "text/event-stream",
                    "content-type": "application/json",
                    "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36",
                    "x-vqd-4": conversation.vqd,
                }
                data = {
                    "model": model,
                    "messages": conversation.message_history,
                }
                await cls.sleep()
                async with session.post(cls.api_endpoint, json=data, headers=headers, proxy=proxy) as response:
                    await raise_for_status(response)
                    reason = None
                    full_message = ""
                    # The endpoint streams server-sent events; payload lines are
                    # prefixed with "data:".
                    async for line in response.content:
                        line = line.decode("utf-8").strip()
                        if line.startswith("data:"):
                            try:
                                message = json.loads(line[5:].strip())
                            except json.JSONDecodeError:
                                continue
                            if message.get("action") == "error":
                                error_type = message.get("type", "")
                                if message.get("status") == 429:
                                    if error_type == "ERR_CONVERSATION_LIMIT":
                                        raise ConversationLimitError(error_type)
                                    raise RateLimitError(error_type)
                                raise DuckDuckGoSearchException(error_type)
                            if "message" in message:
                                if message["message"]:
                                    yield message["message"]
                                    full_message += message["message"]
                                    reason = "length"
                                else:
                                    reason = "stop"
                    if return_conversation:
                        conversation.message_history.append({"role": "assistant", "content": full_message})
                        # The server rotates the VQD token on each response.
                        conversation.vqd = response.headers.get("x-vqd-4", conversation.vqd)
                        conversation.cookies = {
                            n: c.value
                            for n, c in session.cookie_jar.filter_cookies(cls.url).items()
                        }
                        yield conversation
                    if reason is not None:
                        yield FinishReason(reason)
        except asyncio.TimeoutError as e:
            raise TimeoutError(f"Request timed out: {e}")
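

# A minimal usage sketch (illustrative, not part of the original module; it
# assumes this file lives inside the g4f package so the relative imports
# resolve, e.g. `from g4f.Provider.DDG import DDG`):
#
#     import asyncio
#
#     async def demo():
#         messages = [{"role": "user", "content": "Hello!"}]
#         async for chunk in DDG.create_async_generator(model="gpt-4o-mini", messages=messages):
#             if isinstance(chunk, str):
#                 print(chunk, end="", flush=True)
#
#     asyncio.run(demo())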