# run_tools.py
from __future__ import annotations

import re
import json
import asyncio
import time
from pathlib import Path
from typing import Optional, Callable, AsyncIterator, Iterator

from ..typing import Messages
from ..providers.helper import filter_none
from ..providers.asyncio import to_async_iterator
from ..providers.response import Reasoning
from ..providers.types import ProviderType
from ..cookies import get_cookies_dir
from .web_search import do_search, get_search_message
from .files import read_bucket, get_bucket_dir
from .. import debug
  17. BUCKET_INSTRUCTIONS = """
  18. Instruction: Make sure to add the sources of cites using [[domain]](Url) notation after the reference. Example: [[a-z0-9.]](http://example.com)
  19. """
  20. def validate_arguments(data: dict) -> dict:
  21. if "arguments" in data:
  22. if isinstance(data["arguments"], str):
  23. data["arguments"] = json.loads(data["arguments"])
  24. if not isinstance(data["arguments"], dict):
  25. raise ValueError("Tool function arguments must be a dictionary or a json string")
  26. else:
  27. return filter_none(**data["arguments"])
  28. else:
  29. return {}
  30. def get_api_key_file(cls) -> Path:
  31. return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
  32. async def async_iter_run_tools(provider: ProviderType, model: str, messages, tool_calls: Optional[list] = None, **kwargs):
  33. # Handle web_search from kwargs
  34. web_search = kwargs.get('web_search')
  35. if web_search:
  36. try:
  37. messages = messages.copy()
  38. web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
  39. messages[-1]["content"] = await do_search(messages[-1]["content"], web_search)
  40. except Exception as e:
  41. debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
  42. # Keep web_search in kwargs for provider native support
  43. pass
  44. # Read api_key from config file
  45. if getattr(provider, "needs_auth", False) and "api_key" not in kwargs:
  46. auth_file = get_api_key_file(provider)
  47. if auth_file.exists():
  48. with auth_file.open("r") as f:
  49. auth_result = json.load(f)
  50. if "api_key" in auth_result:
  51. kwargs["api_key"] = auth_result["api_key"]
  52. if tool_calls is not None:
  53. for tool in tool_calls:
  54. if tool.get("type") == "function":
  55. if tool.get("function", {}).get("name") == "search_tool":
  56. tool["function"]["arguments"] = validate_arguments(tool["function"])
  57. messages = messages.copy()
  58. messages[-1]["content"] = await do_search(
  59. messages[-1]["content"],
  60. **tool["function"]["arguments"]
  61. )
  62. elif tool.get("function", {}).get("name") == "continue":
  63. last_line = messages[-1]["content"].strip().splitlines()[-1]
  64. content = f"Carry on from this point:\n{last_line}"
  65. messages.append({"role": "user", "content": content})
  66. elif tool.get("function", {}).get("name") == "bucket_tool":
  67. def on_bucket(match):
  68. return "".join(read_bucket(get_bucket_dir(match.group(1))))
  69. has_bucket = False
  70. for message in messages:
  71. if "content" in message and isinstance(message["content"], str):
  72. new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
  73. if new_message_content != message["content"]:
  74. has_bucket = True
  75. message["content"] = new_message_content
  76. if has_bucket and isinstance(messages[-1]["content"], str):
  77. if "\nSource: " in messages[-1]["content"]:
  78. messages[-1]["content"] += BUCKET_INSTRUCTIONS
  79. create_function = provider.get_async_create_function()
  80. response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
  81. async for chunk in response:
  82. yield chunk
  83. def process_thinking_chunk(chunk: str, start_time: float = 0) -> tuple[float, list]:
  84. """Process a thinking chunk and return timing and results."""
  85. results = []
  86. # Handle non-thinking chunk
  87. if not start_time and "<think>" not in chunk:
  88. return 0, [chunk]
  89. # Handle thinking start
  90. if "<think>" in chunk and not "`<think>`" in chunk:
  91. before_think, *after = chunk.split("<think>", 1)
  92. if before_think:
  93. results.append(before_think)
  94. results.append(Reasoning(status="🤔 Is thinking...", is_thinking="<think>"))
  95. if after and after[0]:
  96. results.append(Reasoning(after[0]))
  97. return time.time(), results
  98. # Handle thinking end
  99. if "</think>" in chunk:
  100. before_end, *after = chunk.split("</think>", 1)
  101. if before_end:
  102. results.append(Reasoning(before_end))
  103. thinking_duration = time.time() - start_time if start_time > 0 else 0
  104. status = f"Thought for {thinking_duration:.2f}s" if thinking_duration > 1 else "Finished"
  105. results.append(Reasoning(status=status, is_thinking="</think>"))
  106. if after and after[0]:
  107. results.append(after[0])
  108. return 0, results
  109. # Handle ongoing thinking
  110. if start_time:
  111. return start_time, [Reasoning(chunk)]
  112. return start_time, [chunk]
  113. def iter_run_tools(
  114. iter_callback: Callable,
  115. model: str,
  116. messages: Messages,
  117. provider: Optional[str] = None,
  118. tool_calls: Optional[list] = None,
  119. **kwargs
  120. ) -> AsyncIterator:
  121. # Handle web_search from kwargs
  122. web_search = kwargs.get('web_search')
  123. if web_search:
  124. try:
  125. messages = messages.copy()
  126. web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
  127. messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], web_search))
  128. except Exception as e:
  129. debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
  130. # Keep web_search in kwargs for provider native support
  131. pass
  132. # Read api_key from config file
  133. if provider is not None and provider.needs_auth and "api_key" not in kwargs:
  134. auth_file = get_api_key_file(provider)
  135. if auth_file.exists():
  136. with auth_file.open("r") as f:
  137. auth_result = json.load(f)
  138. if "api_key" in auth_result:
  139. kwargs["api_key"] = auth_result["api_key"]
  140. if tool_calls is not None:
  141. for tool in tool_calls:
  142. if tool.get("type") == "function":
  143. if tool.get("function", {}).get("name") == "search_tool":
  144. tool["function"]["arguments"] = validate_arguments(tool["function"])
  145. messages[-1]["content"] = get_search_message(
  146. messages[-1]["content"],
  147. raise_search_exceptions=True,
  148. **tool["function"]["arguments"]
  149. )
  150. elif tool.get("function", {}).get("name") == "continue_tool":
  151. if provider not in ("OpenaiAccount", "HuggingFace"):
  152. last_line = messages[-1]["content"].strip().splitlines()[-1]
  153. content = f"Carry on from this point:\n{last_line}"
  154. messages.append({"role": "user", "content": content})
  155. else:
  156. # Enable provider native continue
  157. if "action" not in kwargs:
  158. kwargs["action"] = "continue"
  159. elif tool.get("function", {}).get("name") == "bucket_tool":
  160. def on_bucket(match):
  161. return "".join(read_bucket(get_bucket_dir(match.group(1))))
  162. has_bucket = False
  163. for message in messages:
  164. if "content" in message and isinstance(message["content"], str):
  165. new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
  166. if new_message_content != message["content"]:
  167. has_bucket = True
  168. message["content"] = new_message_content
  169. if has_bucket and isinstance(messages[-1]["content"], str):
  170. if "\nSource: " in messages[-1]["content"]:
  171. messages[-1]["content"] = messages[-1]["content"]["content"] + BUCKET_INSTRUCTIONS
  172. thinking_start_time = 0
  173. for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
  174. if not isinstance(chunk, str):
  175. yield chunk
  176. continue
  177. thinking_start_time, results = process_thinking_chunk(chunk, thinking_start_time)
  178. for result in results:
  179. yield result