|
|
|
@ -10,9 +10,9 @@ import torch
|
|
|
|
|
from trl.trainer.grpo_trainer import apply_chat_template
|
|
|
|
|
|
|
|
|
|
from config import logger
|
|
|
|
|
from src.deepsearch.prompts import build_user_prompt, get_system_prompt
|
|
|
|
|
from src.deepsearch.search_module import search
|
|
|
|
|
from src.deepsearch.tokenizer_adapter import TokenizerAdapter
|
|
|
|
|
from src.prompts import build_user_prompt, get_system_prompt
|
|
|
|
|
from src.search_module import search
|
|
|
|
|
from src.tokenizer_adapter import TokenizerAdapter
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_search_query(text: str) -> str | None:
|
|
|
|
@ -36,9 +36,15 @@ class AgenticOutputs:
|
|
|
|
|
class Agent:
|
|
|
|
|
"""Base agent class for handling tool-based conversations."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, tokenizer_adapter: TokenizerAdapter):
|
|
|
|
|
"""Initialize the agent with a tokenizer adapter."""
|
|
|
|
|
def __init__(self, tokenizer_adapter: TokenizerAdapter, search_fn=None):
|
|
|
|
|
"""Initialize the agent with a tokenizer adapter and optional search function.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
tokenizer_adapter: Tokenizer adapter for handling text
|
|
|
|
|
search_fn: Optional custom search function. If None, uses default search.
|
|
|
|
|
"""
|
|
|
|
|
self.tokenizer_adapter = tokenizer_adapter
|
|
|
|
|
self.search_fn = search_fn or search # Use provided search function or default
|
|
|
|
|
|
|
|
|
|
def get_initial_chat(self, question: str) -> dict:
|
|
|
|
|
"""Initialize a chat state with the question."""
|
|
|
|
@ -113,11 +119,10 @@ class Agent:
|
|
|
|
|
search_query = extract_search_query(assistant_response)
|
|
|
|
|
if search_query:
|
|
|
|
|
logger.info(f"🔍 Search Query: {search_query}")
|
|
|
|
|
results = search(search_query, return_type=str, results=2)
|
|
|
|
|
results = self.search_fn(search_query, return_type=str, results=2)
|
|
|
|
|
formatted_results = f"<information>{results}</information>"
|
|
|
|
|
logger.info(f"ℹ️ Information: {formatted_results}")
|
|
|
|
|
|
|
|
|
|
# chat_state["messages"].append({"role": "ipython", "content": formatted_results})
|
|
|
|
|
chat_state["messages"].append({"role": "user", "content": formatted_results})
|
|
|
|
|
logger.debug("Added search results to chat state")
|
|
|
|
|
except Exception as e:
|