@@ -3,7 +3,6 @@ import json
 import os
 import re
 import time
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from copy import deepcopy
 from datetime import datetime
 from functools import wraps
@@ -15,12 +14,10 @@ from flashrag.generator.generator import BaseGenerator
 from flashrag.pipeline import BasicPipeline
 from flashrag.retriever.retriever import BaseTextRetriever
 from flashrag.utils import get_dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
 from config import logger
 from src.agent import Agent, AgenticOutputs
 from src.prompts import build_user_prompt, get_system_prompt
 from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
@@ -157,7 +154,7 @@ class ReSearchPipeline(BasicPipeline):
         adapter = R1DistilTokenizerAdapter()
         logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")
 
-        def retriever_search(query: str, return_type=str, results: int = 2):
+        def retriever_search(query: str, return_type=str, results: int = 5):
             try:
                 search_results = self.retriever._search(query, num=results)
                 return self.format_search_results(search_results)
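
The change above only raises the closure's default passage count from 2 to 5. A minimal sketch of the same closure pattern, in case the nesting is unclear; DummyRetriever and make_retriever_search are illustrative stand-ins, not part of the pipeline, whereas the real code captures self.retriever (a FlashRAG BaseTextRetriever) and formats results with self.format_search_results:

    class DummyRetriever:
        def _search(self, query: str, num: int):
            # Stand-in for BaseTextRetriever._search: return `num` fake documents.
            return [{"contents": f"passage {i} for {query!r}"} for i in range(num)]

    def make_retriever_search(retriever, default_results: int = 5):
        # Same shape as the nested retriever_search closure above: capture the
        # retriever and expose a plain callable the agent can use as its search tool.
        def retriever_search(query: str, results: int = default_results):
            docs = retriever._search(query, num=results)
            return "\n".join(doc["contents"] for doc in docs)
        return retriever_search

    search = make_retriever_search(DummyRetriever())
    print(search("what is flashrag"))  # five stand-in passages by default
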
@@ -216,8 +213,8 @@ class ReSearchPipeline(BasicPipeline):
             logger.error("Ensure dataset items have a 'question' key or attribute.")
             return dataset
 
-        agent_max_generations = getattr(self.config, "agent_max_generations", 20)
-        generator_max_output_len = getattr(self.config, "generator_max_output_len", 1024)
+        agent_max_generations = getattr(self.config, "agent_max_generations", 32)
+        generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576)
 
         try:
             logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
@@ -422,7 +419,7 @@ class SGLRemoteGenerator(BaseGenerator):
         except requests.exceptions.RequestException as e:
             logger.error(f"Network error during generation: {str(e)}", exc_info=True)
             raise
-        except json.JSONDecodeError as e:
+        except json.JSONDecodeError:
             response_text = "Unknown (error occurred before response object assignment)"
             if "response" in locals() and hasattr(response, "text"):
                 response_text = response.text[:500]
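
The only change above is dropping the unused `as e` binding on the JSONDecodeError handler; the locals()/hasattr guard stays because `response` may never have been assigned if the request itself failed. A minimal sketch of that defensive pattern, where fetch_generation, the URL, and the payload are illustrative assumptions rather than the generator's real API:

    import json
    import requests

    def fetch_generation(url: str = "http://localhost:30000/generate"):
        try:
            response = requests.post(url, json={"text": "ping"}, timeout=10)
            return json.loads(response.text)
        except requests.exceptions.RequestException as e:
            print(f"Network error during generation: {e}")
            raise
        except json.JSONDecodeError:
            # No `as e`: the exception object is not used here. The locals() check
            # mirrors the original's guard for the case where `response` was never
            # assigned before the decode failed.
            response_text = "Unknown (error occurred before response object assignment)"
            if "response" in locals() and hasattr(response, "text"):
                response_text = response.text[:500]
            print(f"Server returned a non-JSON body: {response_text}")
            raise
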
@@ -561,8 +558,8 @@ def research(args: argparse.Namespace, config: Config):
         logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
         return
 
-    agent_max_generations = getattr(config, "agent_max_generations", 20)
-    generator_max_output_len = getattr(config, "generator_max_output_len", 1024)
+    agent_max_generations = getattr(config, "agent_max_generations", 32)
+    generator_max_output_len = getattr(config, "generator_max_output_len", 24576)
 
     try:
         logger.info("🏃 Starting pipeline run...")