You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

657 lines
28 KiB

import argparse
import json
import os
import re
import time
from copy import deepcopy
from datetime import datetime
from functools import wraps
import numpy as np
import requests
from flashrag.config import Config
from flashrag.generator.generator import BaseGenerator
from flashrag.pipeline import BasicPipeline
from flashrag.retriever.retriever import BaseTextRetriever
from flashrag.utils import get_dataset
from transformers import AutoTokenizer
from config import logger
from src.agent import Agent, AgenticOutputs
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
def retry(max_retries=10, sleep=1, exceptions=(Exception,)):
    """Decorator that retries a function with exponential backoff.

    The wrapped function is attempted up to ``max_retries`` times, and always at
    least once. After failed attempt ``k`` (0-based), the wrapper sleeps
    ``sleep * 2**k`` seconds before the next attempt.

    Args:
        max_retries: Maximum number of attempts. Values below 1 are treated as 1,
            so the wrapped function is always called.
        sleep: Base backoff delay in seconds.
        exceptions: Exception class, or tuple of classes, that triggers a retry.
            Any other exception propagates immediately without retrying. The
            default ``(Exception,)`` retries on any ordinary exception.

    Returns:
        A decorator that adds the retry behaviour to a function.

    Raises:
        The exception from the final attempt, re-raised with its original
        traceback once all attempts are exhausted.
    """
    # Without this clamp, max_retries <= 0 would skip the loop entirely, and the
    # wrapped function would silently never run.
    attempts = max(1, max_retries)

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_name = func.__name__
            for attempt in range(attempts):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    logger.warning(f"Attempt {attempt + 1} of {func_name} failed: {e}")
                    if attempt == attempts - 1:
                        logger.error(f"Function {func_name} failed after {attempts} retries.", exc_info=True)
                        raise
                    backoff_time = sleep * (2**attempt)
                    logger.info(f"Retrying {func_name} in {backoff_time:.2f} seconds...")
                    time.sleep(backoff_time)
            # Unreachable: the final attempt either returns or re-raises above.
            raise RuntimeError(f"Function {func_name} retry logic finished unexpectedly.")

        return wrapper

    return decorator
class RemoteRetriever(BaseTextRetriever):
    """Client for a remote retrieval HTTP service, with retry and logging.

    Config keys read:
        remote_retriever_url: ``host:port`` of the service, or a full URL that
            already includes a scheme. Defaults to ``localhost:8001``.
        retriever_topk: Default number of documents requested per query.
            Defaults to 5.
    """

    def __init__(self, config: Config):
        """Initializes the RemoteRetriever from ``config``."""
        super().__init__(config)
        raw_url = str(getattr(config, "remote_retriever_url", "localhost:8001")).rstrip("/")
        # Prepend a scheme only to bare ``host:port`` values. Otherwise an explicit
        # "http://host:port" would become "http://http://host:port".
        self.remote_url = raw_url if "://" in raw_url else f"http://{raw_url}"
        self.topk = getattr(config, "retriever_topk", 5)
        logger.info(f"🔗 Remote retriever URL: {self.remote_url}")

    @retry(max_retries=3, sleep=2)
    def _search(self, query: str, num: int | None = None, return_score: bool = False) -> list[dict]:
        """POST a single query to ``{remote_url}/search``.

        Args:
            query: Query text.
            num: Number of documents to request (sent as ``top_n``). ``None``
                uses ``self.topk``.
            return_score: Forwarded to the service as ``return_score``.

        Returns:
            The decoded JSON response body. NOTE(review): assumed to be a list of
            document dicts; confirm the service's shape when scores are requested.

        Raises:
            requests.exceptions.RequestException: On a timeout (30 s), connection
                or HTTP error. Each failure is logged and re-raised; the ``retry``
                decorator re-attempts before propagating.
        """
        num = num if num is not None else self.topk
        url = f"{self.remote_url}/search"
        try:
            response = requests.post(
                url,
                json={"query": query, "top_n": num, "return_score": return_score},
                timeout=30,
            )
            response.raise_for_status()
            results = response.json()
            return results
        except requests.exceptions.Timeout:
            logger.error(f"Search request timed out after 30 seconds for query: {query[:50]}...")
            raise
        except requests.exceptions.ConnectionError:
            logger.error(f"Could not connect to search service at {url}")
            raise
        except requests.exceptions.RequestException as e:
            logger.error(f"Search request failed: {e}", exc_info=True)
            raise
        except Exception as e:
            logger.error(f"Unexpected search error: {str(e)}", exc_info=True)
            raise

    @retry(max_retries=3, sleep=2)
    def _batch_search(
        self, queries: list[str], num: int | None = None, return_score: bool = False
    ) -> list[list[dict]]:
        """POST several queries at once to ``{remote_url}/batch_search``.

        Args:
            queries: Query texts, sent together as ``query``.
            num: Number of documents per query (sent as ``top_n``). ``None`` uses
                ``self.topk``.
            return_score: Forwarded to the service as ``return_score``.

        Returns:
            The decoded JSON response body. NOTE(review): assumed to be one list of
            document dicts per query; confirm against the service.

        Raises:
            requests.exceptions.RequestException: On a timeout (60 s), connection
                or HTTP error. Each failure is logged and re-raised; the ``retry``
                decorator re-attempts before propagating.
        """
        num = num if num is not None else self.topk
        url = f"{self.remote_url}/batch_search"
        try:
            response = requests.post(
                url,
                json={"query": queries, "top_n": num, "return_score": return_score},
                timeout=60,
            )
            response.raise_for_status()
            results = response.json()
            return results
        except requests.exceptions.Timeout:
            logger.error(f"Batch search request timed out after 60 seconds for {len(queries)} queries.")
            raise
        except requests.exceptions.ConnectionError:
            logger.error(f"Could not connect to batch search service at {url}")
            raise
        except requests.exceptions.RequestException as e:
            logger.error(f"Batch search request failed: {e}", exc_info=True)
            raise
        except Exception as e:
            logger.error(f"Unexpected batch search error: {str(e)}", exc_info=True)
            raise
class ReSearchPipeline(BasicPipeline):
    """Pipeline for the ReSearch method, using an ``Agent`` for generation and tool use.

    The agent receives a search callback backed by ``self.retriever``. When the
    agent finishes, the last ``<answer>...</answer>`` span of each final response
    is stored on the matching dataset item as ``pred``, and the full response is
    stored as ``final_response``. The base class's ``evaluate`` then runs
    optionally.
    """

    # Compiled once and reused. Both extractors return the LAST tagged span.
    _SEARCH_PATTERN = re.compile(r"<search>(.*?)</search>", re.DOTALL)
    _ANSWER_PATTERN = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    def __init__(
        self, config: Config, retriever: BaseTextRetriever | None = None, generator: BaseGenerator | None = None
    ):
        """Initializes the ReSearchPipeline.

        Args:
            config: FlashRAG config. ``generator_model_path`` selects the tokenizer.
            retriever: Retriever behind the agent's search tool. Defaults to
                ``RemoteRetriever(config)``.
            generator: Generator driven by the agent. Defaults to
                ``SGLRemoteGenerator(config)``.

        Raises:
            Exception: Re-raised if the tokenizer cannot be loaded.
        """
        super().__init__(config)
        logger.info("🔧 Initializing ReSearchPipeline...")
        self.retriever = retriever or RemoteRetriever(config)
        self.generator = generator or SGLRemoteGenerator(config)
        self.tokenizer = self._load_tokenizer(config)
        adapter = self._select_adapter(self.tokenizer.name_or_path)
        logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")

        def retriever_search(query: str, return_type=str, results: int | None = None):
            """Search callback handed to the Agent. Returns an ``<information>`` block.

            ``results`` defaults to None, so the retriever's own top-k applies (for
            ``RemoteRetriever`` this is the ``retriever_topk`` config value). Pass
            an int to request a specific count. ``return_type`` is accepted for
            interface compatibility and is unused. Failures are logged and reported
            to the agent as a well-formed ``<information>`` block instead of raising.
            """
            try:
                search_results = self.retriever._search(query, num=results)
                return self.format_search_results(search_results)
            except Exception as e:
                logger.error(f"Error during agent's retriever search for query '{query[:50]}...': {e}", exc_info=True)
                return "<information>Search failed due to an internal error.</information>"

        self.agent = Agent(adapter, search_fn=retriever_search)
        logger.info("✅ Agent initialized.")
        logger.info("✅ ReSearchPipeline initialized successfully.")

    @staticmethod
    def _load_tokenizer(config: Config):
        """Load the generator tokenizer, ensure it has a pad token, and set left padding."""
        try:
            tokenizer = AutoTokenizer.from_pretrained(config.generator_model_path, trust_remote_code=True)
            if not tokenizer.pad_token:
                logger.warning("Tokenizer does not have a pad token; setting to eos_token.")
                tokenizer.pad_token = tokenizer.eos_token
            tokenizer.padding_side = "left"
            logger.info("✅ Tokenizer initialized.")
        except Exception as e:
            logger.error(f"Failed to initialize tokenizer: {e}", exc_info=True)
            raise
        return tokenizer

    @staticmethod
    def _select_adapter(name_or_path: str):
        """Choose the tokenizer adapter from the tokenizer's name or local path."""
        tokenizer_name = name_or_path.lower()
        # Match without the "deepseek-ai/" hub prefix, so that locally stored
        # R1-Distill models (including R1-Distill-Llama) are not misclassified
        # as Llama by the check below.
        if "deepseek-r1-distill" in tokenizer_name:
            return R1DistilTokenizerAdapter()
        if "llama" in tokenizer_name:
            return LlamaTokenizerAdapter()
        logger.warning(f"Unknown tokenizer type '{tokenizer_name}', defaulting to R1DistilTokenizerAdapter.")
        return R1DistilTokenizerAdapter()

    def format_search_results(self, search_results: list[dict]) -> str:
        """Format search hits as one ``<information>`` block for the agent prompt.

        Each hit's ``contents`` is truncated to 500 characters, and ``...`` is
        appended when it was cut. A missing or ``None`` ``contents`` is shown as
        ``N/A``.
        """
        if not search_results:
            return "<information>No results found.</information>"
        max_content_len = 500
        snippets = []
        for i, result in enumerate(search_results):
            contents = result.get("contents")
            text = "N/A" if contents is None else str(contents)
            suffix = "..." if len(text) > max_content_len else ""
            snippets.append(f"Result {i + 1}: {text[:max_content_len]}{suffix}")
        formatted = "\n-------\n".join(snippets)
        return f"<information>{formatted}</information>"

    def extract_search_query(self, text: str) -> str | None:
        """Return the stripped content of the last ``<search>`` span in ``text``, or None."""
        matches = self._SEARCH_PATTERN.findall(text)
        return matches[-1].strip() if matches else None

    def extract_answer(self, text: str) -> str | None:
        """Return the stripped content of the last ``<answer>`` span in ``text``, or None."""
        matches = self._ANSWER_PATTERN.findall(text)
        return matches[-1].strip() if matches else None

    @staticmethod
    def _set_output(item, key: str, value) -> bool:
        """Store ``value`` under ``key`` on a FlashRAG-style item or a dict.

        Returns:
            True if the item was updated, False if its type is unsupported.
        """
        if hasattr(item, "update_output"):
            item.update_output(key, value)
            return True
        if isinstance(item, dict):
            item[key] = value
            return True
        return False

    def _mark_all(self, dataset, marker: str) -> None:
        """Set ``pred`` to ``marker`` on every item that supports it."""
        for item in dataset:
            self._set_output(item, "pred", marker)

    def _run_agent(self, questions: list[str]) -> list[str] | None:
        """Run the agent over ``questions``.

        Returns:
            The agent's final responses, or None if inference failed. The failure
            is logged.
        """
        agent_max_generations = getattr(self.config, "agent_max_generations", 32)
        generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576)
        try:
            logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
            agent_outputs: AgenticOutputs = self.agent.run_agent(
                generate_fn=self.generator.generate,
                tokenizer=self.tokenizer,
                questions=questions,
                max_generations=agent_max_generations,
                max_new_tokens=generator_max_output_len,
            )
            final_responses = agent_outputs.final_response_str
            logger.info(f"✅ Agent inference completed. Received {len(final_responses)} final responses.")
            return final_responses
        except Exception as e:
            logger.error(f"Agent run failed during inference: {e}", exc_info=True)
            return None

    def _store_predictions(self, dataset, final_responses: list[str]) -> None:
        """Write extracted answers and raw responses onto items, pairing them by position."""
        num_updated = 0
        num_missing_answers = 0
        for i, (item, response) in enumerate(zip(dataset, final_responses)):
            answer = self.extract_answer(response)
            if answer is None:
                num_missing_answers += 1
            pred_to_save = answer if answer is not None else ""
            if self._set_output(item, "pred", pred_to_save):
                self._set_output(item, "final_response", response)
                num_updated += 1
            else:
                logger.warning(f"Item {i} has unknown type {type(item)}, cannot update with prediction.")
        logger.info(f"Updated {num_updated}/{len(dataset)} dataset items with predictions.")
        if num_missing_answers > 0:
            logger.warning(f"{num_missing_answers} items had no <answer> tag.")

    def run(self, dataset, do_eval: bool = True, pred_process_fun=None):
        """Run the agent over ``dataset``, store predictions, and optionally evaluate.

        Args:
            dataset: Sized iterable of items that expose ``question`` either as an
                attribute or as a ``"question"`` key.
            do_eval: Whether to call ``self.evaluate`` afterwards.
            pred_process_fun: Forwarded to ``self.evaluate``.

        Returns:
            The dataset, or the value returned by ``evaluate`` when evaluation
            succeeds. Each item's ``pred`` is set to the extracted answer (``""``
            if there is no ``<answer>`` tag). On failure it is set to
            ``"AGENT_ERROR"`` or ``"RESPONSE_COUNT_MISMATCH"`` instead.
        """
        logger.info(f"🏃 Starting ReSearch pipeline run with {len(dataset)} items...")
        try:
            questions = [item.question if hasattr(item, "question") else item["question"] for item in dataset]
        except (KeyError, AttributeError, TypeError) as e:
            logger.error(f"Failed to extract questions from dataset items. Error: {e}", exc_info=True)
            logger.error("Ensure dataset items have a 'question' key or attribute.")
            return dataset

        final_responses = self._run_agent(questions)
        if final_responses is None:
            # Mark every item so that evaluation can still run over the whole dataset.
            logger.warning("Agent run failed, attempting evaluation with potentially incomplete results.")
            self._mark_all(dataset, "AGENT_ERROR")
        else:
            logger.info("📝 Extracting answers and updating dataset items...")
            if len(final_responses) == len(dataset):
                self._store_predictions(dataset, final_responses)
            else:
                logger.error(
                    f"Mismatch between dataset size ({len(dataset)}) and number of agent responses ({len(final_responses)}). Cannot reliably update dataset."
                )
                self._mark_all(dataset, "RESPONSE_COUNT_MISMATCH")

        if do_eval:
            logger.info("📊 Evaluating results using BasicPipeline.evaluate...")
            try:
                dataset = self.evaluate(dataset, do_eval=True, pred_process_fun=pred_process_fun)
                logger.info("✅ Evaluation completed via base class method.")
            except Exception as e:
                logger.error(f"Error during BasicPipeline.evaluate: {e}", exc_info=True)
                logger.warning("Evaluation may be incomplete.")
        else:
            logger.info("Skipping evaluation step as do_eval=False.")
        logger.info("✅ ReSearch pipeline run finished.")
        return dataset
class SGLRemoteGenerator(BaseGenerator):
    """Decoder-only generator backed by a remote SGLang HTTP service.

    Config keys read:
        sgl_remote_url: ``host:port`` of the service, or a full URL with a scheme.
            Defaults to ``localhost:8002``. ``/generate`` and ``/health`` are
            appended to it.
        generator_model_path: Required. Used to load a tokenizer.
        generation_params: Default sampling parameters. Per-call params passed to
            ``generate`` take precedence over them.
        temperature, generator_max_output_len, top_logprobs_num: Optional fallbacks
            used by ``generate``.
    """

    # Since requests 2.27, Response.json() raises requests.exceptions.JSONDecodeError.
    # That class also subclasses RequestException, so it must be matched before the
    # RequestException handler. Older requests versions lack the attribute, so fall
    # back to the stdlib error.
    _JSON_DECODE_ERRORS = (
        json.JSONDecodeError,
        getattr(requests.exceptions, "JSONDecodeError", json.JSONDecodeError),
    )

    def __init__(self, config: Config):
        """Initializes the SGLRemoteGenerator.

        Raises:
            ValueError: If ``generator_model_path`` is missing from ``config``.
            Exception: Re-raised if the tokenizer cannot be loaded.
        """
        super().__init__(config)
        logger.info("🔧 Initializing SGLRemoteGenerator...")
        sgl_url = str(getattr(config, "sgl_remote_url", "localhost:8002")).rstrip("/")
        # Add a scheme only to bare host:port values, so "http://host:port" is not doubled.
        base_url = sgl_url if "://" in sgl_url else f"http://{sgl_url}"
        self.sgl_remote_url = f"{base_url}/generate"
        self.health_check_url = f"{base_url}/health"
        logger.info(f"🔗 Remote Generator URL: {self.sgl_remote_url}")
        self.model_path = getattr(config, "generator_model_path", None)
        if not self.model_path:
            logger.error("generator_model_path not found in config!")
            raise ValueError("generator_model_path is required for SGLRemoteGenerator")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            logger.info("✅ Tokenizer loaded for generator.")
        except Exception as e:
            logger.error(f"Failed to load tokenizer for generator from {self.model_path}: {e}", exc_info=True)
            raise
        # Tolerate an explicit ``generation_params: null`` in the config file.
        self.generation_params = getattr(config, "generation_params", None) or {}
        self.config = config
        self._check_health()

    def _check_health(self):
        """Best-effort GET of the health endpoint. Failures are logged, not raised."""
        try:
            test_response = requests.get(self.health_check_url, timeout=5)
            test_response.raise_for_status()
            logger.info("✅ Remote generator service is available")
        except requests.exceptions.RequestException as e:
            logger.error(f"Could not connect or verify remote generator service at {self.health_check_url}: {str(e)}")
            logger.warning("Please ensure the SGLang service is running and accessible.")

    def _build_sampling_params(self, params: dict) -> dict:
        """Merge per-call ``params`` over ``self.generation_params`` into SGLang sampling params.

        ``do_sample=False`` forces temperature 0.0. Otherwise ``temperature`` falls
        back to ``config.temperature``, then to 0.7. ``max_new_tokens`` falls back
        to ``config.generator_max_output_len``, then to 1024. ``stop`` may be a
        string or a list. Of the remaining keys, only ``top_p`` and ``top_k`` are
        forwarded.
        """
        effective_params = {**self.generation_params, **params}
        sampling_params = {}
        if effective_params.get("do_sample", True) is False:
            sampling_params["temperature"] = 0.0
        else:
            sampling_params["temperature"] = effective_params.get(
                "temperature", getattr(self.config, "temperature", 0.7)
            )
        default_max_tokens = getattr(self.config, "generator_max_output_len", 1024)
        sampling_params["max_new_tokens"] = effective_params.get("max_new_tokens", default_max_tokens)
        stop_sequences = effective_params.get("stop", [])
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        if stop_sequences:
            sampling_params["stop"] = stop_sequences
        for key in ("top_p", "top_k"):
            if key in effective_params:
                sampling_params[key] = effective_params[key]
        return sampling_params

    @staticmethod
    def _text_with_matched_stop(item: dict, stop_sequences: list) -> str:
        """Return ``item["text"]`` and re-append the stop string the server reports as matched.

        The stop string is re-appended only when it is one of the requested
        ``stop_sequences``. A missing or ``None`` ``meta_info`` / ``finish_reason``
        is tolerated.
        """
        text = item.get("text", "")
        finish_reason = (item.get("meta_info") or {}).get("finish_reason") or {}
        matched_stop = finish_reason.get("matched") if isinstance(finish_reason, dict) else None
        if matched_stop and matched_stop in stop_sequences:
            text += matched_stop
        return text

    @staticmethod
    def _token_probabilities(item: dict) -> list:
        """Convert ``meta_info.output_token_logprobs`` into per-token probabilities (exp of entry[0])."""
        logprobs_list = (item.get("meta_info") or {}).get("output_token_logprobs") or []
        return [np.exp(logprob[0]) if (logprob and len(logprob) > 0) else 0.0 for logprob in logprobs_list]

    @retry(max_retries=5, sleep=2)
    def _post_generate(self, payload: dict):
        """POST ``payload`` to the ``/generate`` endpoint and return the decoded JSON.

        The ``retry`` decorator re-attempts with exponential backoff. Every
        failure is logged and re-raised.
        """
        response = None
        try:
            response = requests.post(
                self.sgl_remote_url, json=payload, timeout=120, headers={"Content-Type": "application/json"}
            )
            response.raise_for_status()
            return response.json()
        except self._JSON_DECODE_ERRORS:
            response_text = "Unknown (error occurred before response object assignment)"
            if response is not None:
                response_text = response.text[:500]
            logger.error(
                f"Failed to decode JSON response from {self.sgl_remote_url}. Response text: {response_text}...",
                exc_info=True,
            )
            raise
        except requests.exceptions.Timeout:
            logger.error("Generation request timed out after 120 seconds.")
            raise
        except requests.exceptions.ConnectionError:
            logger.error(f"Could not connect to remote generator service at {self.sgl_remote_url}.")
            raise
        except requests.exceptions.RequestException as e:
            logger.error(f"Network error during generation: {str(e)}", exc_info=True)
            raise
        except Exception as e:
            logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True)
            raise

    def generate(
        self,
        input_list: list[str] | str,
        return_raw_output: bool = False,
        return_scores: bool = False,
        **params,
    ) -> list[str] | tuple[list[str], list[list[float]]] | list[dict]:
        """Generate completions for one prompt or a batch of prompts.

        Args:
            input_list: A prompt string or a list of prompt strings.
            return_raw_output: If set, return the decoded server JSON unchanged.
            return_scores: If set, also request logprobs and return per-token
                probabilities.
            **params: Per-call overrides that take precedence over
                ``self.generation_params``. The keys used are ``do_sample``,
                ``temperature``, ``max_new_tokens``, ``stop``, ``top_p`` and
                ``top_k``.

        Returns:
            A ``list[str]`` of generated texts. Each text has the matched stop
            string re-appended when the server reports one of the requested stops.
            Returns ``(texts, scores)`` when ``return_scores`` is set, or the raw
            JSON when ``return_raw_output`` is set. Empty input returns empty
            results without contacting the server.

        Raises:
            ValueError: If ``input_list`` is not a string or a list of strings.
                This is raised immediately and not retried.
            requests.exceptions.RequestException: If the HTTP request still fails
                after retries.
        """
        if isinstance(input_list, str):
            input_list = [input_list]
        if not isinstance(input_list, list) or not all(isinstance(item, str) for item in input_list):
            raise ValueError("Input must be a string or a list of strings.")
        if not input_list:
            if return_raw_output:
                return []
            return ([], []) if return_scores else []

        sampling_params = self._build_sampling_params(params)
        data_to_remote = {"text": input_list, "sampling_params": sampling_params}
        if return_scores:
            data_to_remote["return_logprob"] = True
            data_to_remote["top_logprobs_num"] = getattr(self.config, "top_logprobs_num", 2)

        response_list = self._post_generate(data_to_remote)
        if return_raw_output:
            return response_list

        try:
            stop_sequences = sampling_params.get("stop") or []
            generated_text = [self._text_with_matched_stop(item, stop_sequences) for item in response_list]
            if return_scores:
                scores = [self._token_probabilities(item) for item in response_list]
                return generated_text, scores
            return generated_text
        except Exception as e:
            logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True)
            raise
def load_dataset_items(config: Config, split: str) -> list[dict | object]:
    """Load the items of one dataset split using FlashRAG's ``get_dataset``.

    Args:
        config: FlashRAG config passed to ``get_dataset``. ``dataset_name``,
            ``data_dir`` and ``dataset_path`` are read only for log messages.
        split: Name of the split to return (e.g. ``"test"``).

    Returns:
        The loaded split. Returns an empty list if the split is unknown, yields no
        data, or loading fails. Errors are logged, never raised.
    """
    dataset_name = getattr(config, "dataset_name", "N/A")
    logger.info(f"📚 Loading dataset: {dataset_name}, Split: {split}")
    try:
        all_splits = get_dataset(config)
        if split not in all_splits:
            logger.error(
                f"Split '{split}' not found in dataset '{dataset_name}'. Available splits: {list(all_splits.keys())}"
            )
            return []
        dataset_items = all_splits[split]
        if dataset_items is None:
            # NOTE(review): FlashRAG's get_dataset appears to map a split whose file
            # is missing to None rather than raising. Confirm for the installed version.
            logger.error(
                f"Split '{split}' of dataset '{dataset_name}' returned no data; "
                f"check that its file exists under '{getattr(config, 'dataset_path', 'N/A')}'."
            )
            return []
        logger.info(f"Successfully loaded {len(dataset_items)} items for split '{split}'.")
        return dataset_items
    except FileNotFoundError:
        logger.error(
            f"Dataset files not found for '{dataset_name}' in '{getattr(config, 'data_dir', 'N/A')}'. Check config and paths."
        )
        return []
    except Exception as e:
        logger.error(f"Error loading dataset using get_dataset: {e}", exc_info=True)
        return []
def _get_item_value(data_item, key_or_attr: str) -> object:
    """Return ``key_or_attr`` from a dict item, or from an attribute-style item; else None."""
    if isinstance(data_item, dict):
        return data_item.get(key_or_attr)
    if hasattr(data_item, key_or_attr):
        return getattr(data_item, key_or_attr)
    return None


def _json_scalar(value) -> object:
    """Return JSON scalars (and None) unchanged and stringify anything else."""
    if isinstance(value, (str, int, float, bool, type(None))):
        return value
    return str(value)


def _ground_truth_of(item) -> object:
    """Return the first available ground-truth answer of ``item`` as a JSON scalar, or None.

    The fields checked, in order, are ``answer``, then the first element of
    ``answers``, then the first element of ``golden_answers``. NOTE(review):
    ``golden_answers`` is assumed to be the field FlashRAG dataset items use;
    confirm for the installed version.
    """
    value = _get_item_value(item, "answer")
    if value is not None:
        return _json_scalar(value)
    for list_key in ("answers", "golden_answers"):
        candidates = _get_item_value(item, list_key)
        if isinstance(candidates, list) and candidates:
            return _json_scalar(candidates[0])
    return None


def save_results(args: argparse.Namespace, config: Config, result_dataset, run_duration: float):
    """Write a text summary and a per-item JSON debug dump for a run.

    The files are written to ``{args.save_dir}/{args.save_note}_summary.txt`` and
    ``{args.save_dir}/{args.save_note}_debug.json``. Each file is written
    independently, and a failure in one is logged without affecting the other.
    Config keys that may be absent are reported as ``N/A``.
    """
    logger.info("💾 Saving results...")
    summary_file = os.path.join(args.save_dir, f"{args.save_note}_summary.txt")
    debug_file = os.path.join(args.save_dir, f"{args.save_note}_debug.json")
    num_items = len(result_dataset)

    logger.info(f"Saving summary results to {summary_file}...")
    try:
        summary_lines = [
            "EVALUATION SUMMARY\n",
            "=================\n\n",
            f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
            f"Run Duration: {run_duration:.2f} seconds\n",
            f"Dataset: {getattr(config, 'dataset_name', 'N/A')} ({args.split} split)\n",
            f"Model: {getattr(config, 'generator_model_path', 'N/A')}\n",
            # remote_retriever_url is optional (RemoteRetriever has a default), so
            # a missing key must not abort the summary.
            f"Retriever: {getattr(config, 'remote_retriever_url', 'N/A')}\n",
            f"Agent Max Generations: {getattr(config, 'agent_max_generations', 'N/A')}\n",
            f"Generator Max Output Len: {getattr(config, 'generator_max_output_len', 'N/A')}\n\n",
            f"Total items processed: {num_items}\n",
            "\nNote: Verification was skipped in this run.\n",
            "Note: Overall metrics (like EM, F1) are usually printed to console by evaluate method.\n",
        ]
        with open(summary_file, "w", encoding="utf-8") as f:
            f.writelines(summary_lines)
        logger.info(f"✅ Summary saved to {summary_file}")
    except Exception as e:
        logger.error(f"Error saving summary file '{summary_file}': {e}", exc_info=True)

    logger.info(f"Saving debug information (predictions & responses) to {debug_file}...")
    try:
        debug_data = []
        for i, item in enumerate(result_dataset):
            try:
                ground_truth = _ground_truth_of(item)
            except Exception as e:
                logger.warning(f"Could not safely get ground truth for item {i}: {e}")
                ground_truth = "ERROR_GETTING_ANSWER"
            try:
                eval_score = _json_scalar(_get_item_value(item, "score"))
            except Exception as e:
                logger.warning(f"Could not safely get score for item {i}: {e}")
                eval_score = "ERROR_GETTING_SCORE"
            debug_data.append(
                {
                    "item_index": i,
                    "question": _get_item_value(item, "question"),
                    "prediction": _get_item_value(item, "pred"),
                    "final_response": _get_item_value(item, "final_response"),
                    "ground_truth": ground_truth,
                    "eval_score": eval_score,
                }
            )
        # Serialize before opening the file, so a non-serializable value cannot
        # leave a truncated debug file behind.
        serialized = json.dumps(debug_data, indent=2, ensure_ascii=False)
        with open(debug_file, "w", encoding="utf-8") as f:
            f.write(serialized)
        logger.info(f"✅ Debug information saved to {debug_file}")
    except Exception as e:
        logger.error(f"Error saving debug file '{debug_file}': {e}", exc_info=True)
def research(args: argparse.Namespace, config: Config):
    """Run the full ReSearch evaluation: load data, build the pipeline, run it, save results.

    If loading the data or building the pipeline fails, the error is logged and the
    run ends early without saving anything. If ``pipeline.run`` fails, the error is
    logged and the input items (possibly partially updated) are saved instead.

    Args:
        args: Parsed CLI arguments. ``split`` is used here, and ``save_dir``,
            ``save_note`` and ``split`` are used by ``save_results``.
        config: Loaded FlashRAG config, passed to the loader and the pipeline.
    """
    logger.info("🚀 Starting research pipeline execution...")
    # Monotonic clock: the duration is unaffected by system clock adjustments.
    start_time = time.perf_counter()
    test_data = load_dataset_items(config, args.split)
    if not test_data:
        logger.error("Failed to load test data. Exiting.")
        return
    try:
        logger.info("🏗️ Building ReSearchPipeline...")
        pipeline = ReSearchPipeline(config)
        logger.info("✅ Pipeline built successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
        return
    try:
        logger.info("🏃 Starting pipeline run...")
        result_dataset = pipeline.run(test_data, do_eval=True)
        logger.info("✅ Pipeline run completed.")
    except Exception as e:
        logger.error(f"Error during pipeline run: {e}", exc_info=True)
        result_dataset = test_data
        logger.warning("Pipeline run failed, attempting to save inputs/partial results.")
    run_duration = time.perf_counter() - start_time
    logger.info(f"Total run duration: {run_duration:.2f} seconds.")
    save_results(args, config, result_dataset, run_duration)
    logger.info("🏁 Research pipeline execution finished.")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser. Options left at ``None`` are not forwarded to Config."""
    parser = argparse.ArgumentParser(description="Running ReSearch Evaluation Pipeline")
    parser.add_argument(
        "--config_path", type=str, default="./eval_config.yaml", help="Path to the main FlashRAG config file."
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="bamboogle",
        help="Name of the dataset (must match config or data_dir structure).",
    )
    parser.add_argument(
        "--split", type=str, default="test", help="Dataset split to evaluate (e.g., test, validation)."
    )
    parser.add_argument("--save_dir", type=str, default="./output_logs", help="Directory to save logs and results.")
    parser.add_argument("--save_note", type=str, default="research_run", help="A note to prepend to saved filenames.")
    parser.add_argument("--data_dir", type=str, help="Override data directory specified in config.")
    parser.add_argument(
        "--sgl_remote_url", type=str, help="Override SGLang remote generator URL (e.g., localhost:8002)."
    )
    parser.add_argument(
        "--remote_retriever_url", type=str, help="Override remote retriever URL (e.g., localhost:8001)."
    )
    parser.add_argument("--generator_model_path", type=str, help="Override generator model path specified in config.")
    parser.add_argument("--retriever_topk", type=int, help="Override retriever top K.")
    parser.add_argument("--generator_max_output_len", type=int, help="Override generator max output length.")
    parser.add_argument("--agent_max_generations", type=int, help="Override agent max interaction turns.")
    return parser


def main() -> None:
    """CLI entry point: parse arguments, build the Config, and run ``research``.

    Exits with status 1 if the save directory cannot be created or the config
    cannot be loaded.
    """
    args = _build_arg_parser().parse_args()
    logger.info(f"Starting evaluation script with arguments: {args}")
    try:
        os.makedirs(args.save_dir, exist_ok=True)
        logger.info(f"💾 Logs and results will be saved to: {args.save_dir}")
    except OSError as e:
        logger.error(f"Could not create save directory '{args.save_dir}': {e}", exc_info=True)
        # SystemExit rather than the site-provided ``exit`` builtin, which is
        # absent under ``python -S`` and in some embedded interpreters.
        raise SystemExit(1)

    # Options used only by this script, or applied explicitly after the Config is built.
    script_only_args = {"config_path", "dataset_name", "split", "save_dir", "save_note"}
    config_overrides = {k: v for k, v in vars(args).items() if v is not None and k not in script_only_args}

    logger.info(f"🔧 Loading configuration from: {args.config_path}")
    try:
        config = Config(args.config_path, config_dict=config_overrides)
        # Write the overrides with item assignment (the same mechanism used for
        # ``dataset_path`` below), so they land in the config's key store rather
        # than only as plain instance attributes.
        # NOTE(review): this relies on attribute reads such as ``config.dataset_name``
        # resolving to that store, as reads of ``generator_model_path`` do. Confirm
        # against the installed Config.
        config["dataset_name"] = args.dataset_name
        if args.data_dir:
            config["data_dir"] = args.data_dir
        logger.info(f"Effective data_dir: {getattr(config, 'data_dir', 'N/A')}")
        logger.info(f"Effective generator_model_path: {getattr(config, 'generator_model_path', 'N/A')}")
        logger.info(f"Effective sgl_remote_url: {getattr(config, 'sgl_remote_url', 'N/A')}")
        logger.info(f"Effective remote_retriever_url: {getattr(config, 'remote_retriever_url', 'N/A')}")
        logger.info("✅ Config loaded and potentially overridden by CLI arguments.")
        config["dataset_path"] = os.path.join(config.data_dir, config.dataset_name)
    except FileNotFoundError:
        logger.error(f"Config file not found at '{args.config_path}'. Please check the path.")
        raise SystemExit(1)
    except Exception as e:
        logger.error(f"Error loading or processing configuration: {e}", exc_info=True)
        raise SystemExit(1)
    research(args, config)


if __name__ == "__main__":
    main()