feat: [WIP] add bench scripts

main
thinhlpg 4 weeks ago
parent bd02305efb
commit 14ef79a4f5

.gitignore vendored

@@ -18,6 +18,7 @@ logs/
*.code-workspace
data/
.gradio/
output_*
# Byte-compiled / optimized / DLL files
__pycache__/

@@ -38,4 +38,5 @@ dependencies = [
"requests>=2.31.0", "requests>=2.31.0",
"tqdm>=4.66.1", "tqdm>=4.66.1",
"tavily-python", "tavily-python",
"sglang[all]>=0.4.5",
]

@@ -0,0 +1,46 @@
# ------------------------------------------------Environment Settings------------------------------------------------#
# Directory paths for data and outputs
data_dir: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/"
save_dir: "/mnt/nas/thinhlpg/code/DeepSearch/logs"
# Seed for reproducibility
seed: 2024
# Whether to save intermediate data
save_intermediate_data: True
save_note: 'experiment'
# -------------------------------------------------Retrieval Settings------------------------------------------------#
# If a remote URL is set, the retriever runs remotely and the following local retriever settings are ignored
use_remote_retriever: True
remote_retriever_url: "localhost:8001"
instruction: ~ # instruction for retrieval model
retrieval_topk: 5 # number of retrieved documents
retrieval_batch_size: 256 # batch size for retrieval
retrieval_use_fp16: True # whether to use fp16 for retrieval model
retrieval_query_max_length: 128 # max length of the query
save_retrieval_cache: False # whether to save the retrieval cache
use_retrieval_cache: False # whether to use the retrieval cache
retrieval_cache_path: ~ # path to the retrieval cache
retrieval_pooling_method: ~ # set automatically if not provided
# -------------------------------------------------Generator Settings------------------------------------------------#
framework: sgl_remote # inference framework of the LLM; supports 'hf', 'vllm', 'fschat', and the remote 'sgl_remote' used here
sgl_remote_url: "localhost:8002"
generator_model: "janhq/250404-llama-3.2-3b-instruct-grpo-03-s250" # name or path of the generator model, used for loading the tokenizer
generator_max_input_len: 2048 # max length of the input
generation_params:
  do_sample: False
  max_tokens: 8192
# -------------------------------------------------Evaluation Settings------------------------------------------------#
# Metrics to evaluate the result
metrics: ['em', 'f1', 'acc', 'precision', 'recall']
# Settings for specific metrics, consumed within the corresponding metric implementations
metric_setting:
  retrieval_recall_topk: 5
save_metric_score: True # whether to save the metric score into txt file
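The generation_params above are translated into SGLang sampling parameters by SGLRemoteGenerator.generate in the evaluation script below: do_sample: False forces greedy decoding via temperature 0.0, and the output length cap is read from max_new_tokens (falling back to generator_max_output_len), so the max_tokens key above appears to be ignored by this pipeline. A minimal standalone sketch of that mapping, for illustration only:

def to_sampling_params(gen_params: dict, default_max_new_tokens: int = 1024) -> dict:
    """Mirror the do_sample/temperature/max_new_tokens handling used by the generator."""
    sampling: dict = {}
    if gen_params.get("do_sample", True) is False:
        # Greedy decoding: SGLang has no do_sample flag, so temperature 0.0 stands in for it
        sampling["temperature"] = 0.0
    else:
        sampling["temperature"] = gen_params.get("temperature", 0.7)
    sampling["max_new_tokens"] = gen_params.get("max_new_tokens", default_max_new_tokens)
    if gen_params.get("stop"):
        sampling["stop"] = gen_params["stop"]
    return sampling

assert to_sampling_params({"do_sample": False}) == {"temperature": 0.0, "max_new_tokens": 1024}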

@@ -0,0 +1,659 @@
import argparse
import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from datetime import datetime
from functools import wraps
import numpy as np
import requests
from flashrag.config import Config
from flashrag.generator.generator import BaseGenerator
from flashrag.pipeline import BasicPipeline
from flashrag.retriever.retriever import BaseTextRetriever
from flashrag.utils import get_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from config import logger
from src.agent import Agent, AgenticOutputs
from src.prompts import build_user_prompt, get_system_prompt
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
def retry(max_retries=10, sleep=1):
"""Decorator to retry a function with exponential backoff."""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
func_name = func.__name__
for attempt in range(max_retries):
try:
result = func(*args, **kwargs)
return result
except Exception as e:
logger.warning(f"Attempt {attempt + 1} of {func_name} failed: {e}")
if attempt == max_retries - 1:
logger.error(f"Function {func_name} failed after {max_retries} retries.", exc_info=True)
raise e
backoff_time = sleep * (2**attempt)
logger.info(f"Retrying {func_name} in {backoff_time:.2f} seconds...")
time.sleep(backoff_time)
logger.error(f"Function {func_name} retry logic finished unexpectedly.")
return None
return wrapper
return decorator
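# Example usage (illustrative): a decorated call retries on any exception,
# backing off 1s, 2s, 4s, ... between attempts before re-raising.
#
#     @retry(max_retries=3, sleep=1)
#     def fetch_health(url: str) -> int:
#         return requests.get(url, timeout=5).status_code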
class RemoteRetriever(BaseTextRetriever):
"""A wrapper for remote retriever service with retry logic and logging."""
def __init__(self, config: Config):
"""Initializes the RemoteRetriever."""
super().__init__(config)
self.remote_url = f"http://{getattr(config, 'remote_retriever_url', 'localhost:8001')}"
self.topk = getattr(config, "retriever_topk", 5)
logger.info(f"🔗 Remote retriever URL: {self.remote_url}")
@retry(max_retries=3, sleep=2)
def _search(self, query: str, num: int | None = None, return_score: bool = False) -> list[dict]:
"""Search for documents using the remote retriever service."""
num = num if num is not None else self.topk
url = f"{self.remote_url}/search"
try:
response = requests.post(
url,
json={"query": query, "top_n": num, "return_score": return_score},
timeout=30,
)
response.raise_for_status()
results = response.json()
return results
except requests.exceptions.Timeout:
logger.error(f"Search request timed out after 30 seconds for query: {query[:50]}...")
raise
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to search service at {url}")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Search request failed: {e}", exc_info=True)
raise
except Exception as e:
logger.error(f"Unexpected search error: {str(e)}", exc_info=True)
raise
@retry(max_retries=3, sleep=2)
def _batch_search(
self, queries: list[str], num: int | None = None, return_score: bool = False
) -> list[list[dict]]:
"""Batch search for documents using the remote retriever service."""
num = num if num is not None else self.topk
url = f"{self.remote_url}/batch_search"
try:
response = requests.post(
url,
json={"query": queries, "top_n": num, "return_score": return_score},
timeout=60,
)
response.raise_for_status()
results = response.json()
return results
except requests.exceptions.Timeout:
logger.error(f"Batch search request timed out after 60 seconds for {len(queries)} queries.")
raise
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to batch search service at {url}")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Batch search request failed: {e}", exc_info=True)
raise
except Exception as e:
logger.error(f"Unexpected batch search error: {str(e)}", exc_info=True)
raise
class ReSearchPipeline(BasicPipeline):
"""Pipeline for ReSearch method using Agent for generation and tool use."""
def __init__(
self, config: Config, retriever: BaseTextRetriever | None = None, generator: BaseGenerator | None = None
):
"""Initializes the ReSearchPipeline."""
super().__init__(config)
logger.info("🔧 Initializing ReSearchPipeline...")
self.retriever = retriever or RemoteRetriever(config)
self.generator = generator or SGLRemoteGenerator(config)
try:
self.tokenizer = AutoTokenizer.from_pretrained(config.generator_model_path, trust_remote_code=True)
if not self.tokenizer.pad_token:
logger.warning("Tokenizer does not have a pad token; setting to eos_token.")
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "left"
logger.info("✅ Tokenizer initialized.")
except Exception as e:
logger.error(f"Failed to initialize tokenizer: {e}", exc_info=True)
raise
tokenizer_name = self.tokenizer.name_or_path.lower()
if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
adapter = R1DistilTokenizerAdapter()
elif "llama" in tokenizer_name:
adapter = LlamaTokenizerAdapter()
else:
logger.warning(f"Unknown tokenizer type '{tokenizer_name}', defaulting to R1DistilTokenizerAdapter.")
adapter = R1DistilTokenizerAdapter()
logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")
def retriever_search(query: str, return_type=str, results: int = 2):
try:
search_results = self.retriever._search(query, num=results)
return self.format_search_results(search_results)
except Exception as e:
logger.error(f"Error during agent's retriever search for query '{query[:50]}...': {e}", exc_info=True)
return "<information>Search failed due to an internal error."
self.agent = Agent(adapter, search_fn=retriever_search)
logger.info("✅ Agent initialized.")
logger.info("✅ ReSearchPipeline initialized successfully.")
def format_search_results(self, search_results: list[dict]) -> str:
"""Formats search results into a string for the agent prompt."""
if not search_results:
return "<information>No results found.</information>"
max_content_len = 500
formatted = "\n-------\n".join(
[
f"Result {i + 1}: {r.get('contents', 'N/A')[:max_content_len]}{'...' if len(r.get('contents', '')) > max_content_len else ''}"
for i, r in enumerate(search_results)
]
)
formatted_str = f"<information>{formatted}</information>"
return formatted_str
def extract_search_query(self, text: str) -> str | None:
"""Extract search query from text between <search> tags."""
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
matches = pattern.findall(text)
if matches:
query = matches[-1].strip()
return query
return None
def extract_answer(self, text: str) -> str | None:
"""Extract answer from text between <answer> tags."""
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
matches = pattern.findall(text)
if matches:
answer = matches[-1].strip()
return answer
return None
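# Illustrative behavior of the two extractors above (the last match wins):
#   extract_search_query("<search> who won? </search>")  -> "who won?"
#   extract_answer("... <answer> Paris </answer>")       -> "Paris"
#   extract_answer("no tags here")                       -> None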
def run(self, dataset, do_eval: bool = True, pred_process_fun=None):
"""Runs the ReSearch pipeline on the dataset using the Agent."""
logger.info(f"🏃 Starting ReSearch pipeline run with {len(dataset)} items...")
try:
questions = [item.question if hasattr(item, "question") else item["question"] for item in dataset]
except (KeyError, AttributeError, TypeError) as e:
logger.error(f"Failed to extract questions from dataset items. Error: {e}", exc_info=True)
logger.error("Ensure dataset items have a 'question' key or attribute.")
return dataset
agent_max_generations = getattr(self.config, "agent_max_generations", 20)
generator_max_output_len = getattr(self.config, "generator_max_output_len", 1024)
try:
logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
agent_outputs: AgenticOutputs = self.agent.run_agent(
generate_fn=self.generator.generate,
tokenizer=self.tokenizer,
questions=questions,
max_generations=agent_max_generations,
max_new_tokens=generator_max_output_len,
)
final_responses = agent_outputs.final_response_str
logger.info(f"✅ Agent inference completed. Received {len(final_responses)} final responses.")
except Exception as e:
logger.error(f"Agent run failed during inference: {e}", exc_info=True)
logger.warning("Agent run failed, attempting evaluation with potentially incomplete results.")
for item in dataset:
if hasattr(item, "update_output"):
item.update_output("pred", "AGENT_ERROR")
elif isinstance(item, dict):
item["pred"] = "AGENT_ERROR"
logger.info("📝 Extracting answers and updating dataset items...")
num_updated = 0
num_missing_answers = 0
if final_responses and len(final_responses) == len(dataset):
for i, item in enumerate(dataset):
response = final_responses[i]
answer = self.extract_answer(response)
pred_to_save = answer if answer is not None else ""
if answer is None:
num_missing_answers += 1
if hasattr(item, "update_output"):
item.update_output("pred", pred_to_save)
item.update_output("final_response", response)
num_updated += 1
elif isinstance(item, dict):
item["pred"] = pred_to_save
item["final_response"] = response
num_updated += 1
else:
logger.warning(f"Item {i} has unknown type {type(item)}, cannot update with prediction.")
logger.info(f"Updated {num_updated}/{len(dataset)} dataset items with predictions.")
if num_missing_answers > 0:
logger.warning(f"{num_missing_answers} items had no <answer> tag.")
elif final_responses:
logger.error(
f"Mismatch between dataset size ({len(dataset)}) and number of agent responses ({len(final_responses)}). Cannot reliably update dataset."
)
for item in dataset:
if hasattr(item, "update_output"):
item.update_output("pred", "RESPONSE_COUNT_MISMATCH")
elif isinstance(item, dict):
item["pred"] = "RESPONSE_COUNT_MISMATCH"
if do_eval:
logger.info("📊 Evaluating results using BasicPipeline.evaluate...")
try:
dataset = self.evaluate(dataset, do_eval=True, pred_process_fun=pred_process_fun)
logger.info("✅ Evaluation completed via base class method.")
except Exception as e:
logger.error(f"Error during BasicPipeline.evaluate: {e}", exc_info=True)
logger.warning("Evaluation may be incomplete.")
else:
logger.info("Skipping evaluation step as do_eval=False.")
logger.info("✅ ReSearch pipeline run finished.")
return dataset
class SGLRemoteGenerator(BaseGenerator):
"""Class for decoder-only generator, based on SGLang remote service."""
def __init__(self, config: Config):
"""Initializes the SGLRemoteGenerator."""
super().__init__(config)
logger.info("🔧 Initializing SGLRemoteGenerator...")
sgl_url = getattr(config, "sgl_remote_url", "localhost:8002")
self.sgl_remote_url = f"http://{sgl_url}/generate"
self.health_check_url = f"http://{sgl_url}/health"
logger.info(f"🔗 Remote Generator URL: {self.sgl_remote_url}")
self.model_path = getattr(config, "generator_model_path", None)
if not self.model_path:
logger.error("generator_model_path not found in config!")
raise ValueError("generator_model_path is required for SGLRemoteGenerator")
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
logger.info("✅ Tokenizer loaded for generator.")
except Exception as e:
logger.error(f"Failed to load tokenizer for generator from {self.model_path}: {e}", exc_info=True)
raise
self.generation_params = getattr(config, "generation_params", {})
self.config = config
self._check_health()
def _check_health(self):
"""Checks the health of the remote generator service."""
try:
test_response = requests.get(self.health_check_url, timeout=5)
test_response.raise_for_status()
logger.info("✅ Remote generator service is available")
except requests.exceptions.RequestException as e:
logger.error(f"Could not connect or verify remote generator service at {self.health_check_url}: {str(e)}")
logger.warning("Please ensure the SGLang service is running and accessible.")
@retry(max_retries=5, sleep=2)
def generate(
self,
input_list: list[str] | str,
return_raw_output: bool = False,
return_scores: bool = False,
**params,
) -> list[str] | tuple[list[str], list[list[float]]] | list[dict]:
"""Generates text using the remote SGLang service."""
if isinstance(input_list, str):
input_list = [input_list]
if not isinstance(input_list, list) or not all(isinstance(item, str) for item in input_list):
raise ValueError("Input must be a string or a list of strings.")
batch_size = len(input_list)
data_to_remote = {"text": input_list}
effective_params = deepcopy(self.generation_params)
effective_params.update(params)
curr_sampling_params = {}
if effective_params.get("do_sample", True) is False:
curr_sampling_params["temperature"] = 0.0
else:
curr_sampling_params["temperature"] = effective_params.get(
"temperature", getattr(self.config, "temperature", 0.7)
)
default_max_tokens = getattr(self.config, "generator_max_output_len", 1024)
curr_sampling_params["max_new_tokens"] = effective_params.get("max_new_tokens", default_max_tokens)
stop_sequences = effective_params.get("stop", [])
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
if stop_sequences:
curr_sampling_params["stop"] = stop_sequences
keys_to_remove = ["do_sample", "temperature", "max_new_tokens", "stop"]
for key in keys_to_remove:
effective_params.pop(key, None)
if "top_p" in effective_params:
curr_sampling_params["top_p"] = effective_params["top_p"]
if "top_k" in effective_params:
curr_sampling_params["top_k"] = effective_params["top_k"]
data_to_remote["sampling_params"] = curr_sampling_params
if return_scores:
data_to_remote["return_logprob"] = True
data_to_remote["top_logprobs_num"] = getattr(self.config, "top_logprobs_num", 2)
try:
response = requests.post(
self.sgl_remote_url, json=data_to_remote, timeout=120, headers={"Content-Type": "application/json"}
)
response.raise_for_status()
response_list = response.json()
if return_raw_output:
return response_list
generated_text = []
for item in response_list:
text = item.get("text", "")
finish_reason = item.get("meta_info", {}).get("finish_reason", {})
matched_stop = finish_reason.get("matched")
if matched_stop and curr_sampling_params.get("stop") and matched_stop in curr_sampling_params["stop"]:
text += matched_stop
generated_text.append(text)
if return_scores:
scores = []
for resp_item in response_list:
logprobs_list = resp_item.get("meta_info", {}).get("output_token_logprobs", [])
token_scores = [
np.exp(logprob[0]) if (logprob and len(logprob) > 0) else 0.0 for logprob in logprobs_list
]
scores.append(token_scores)
return generated_text, scores
else:
return generated_text
except requests.exceptions.Timeout:
logger.error("Generation request timed out after 120 seconds.")
raise
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to remote generator service at {self.sgl_remote_url}.")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Network error during generation: {str(e)}", exc_info=True)
raise
except json.JSONDecodeError as e:
response_text = "Unknown (error occurred before response object assignment)"
if "response" in locals() and hasattr(response, "text"):
response_text = response.text[:500]
logger.error(
f"Failed to decode JSON response from {self.sgl_remote_url}. Response text: {response_text}...",
exc_info=True,
)
raise
except Exception as e:
logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True)
raise
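# For reference, the JSON body POSTed to the SGLang /generate endpoint above has
# this shape (illustrative values; "stop" is only present when stop sequences are
# set, and the logprob fields only when return_scores=True):
#
#     {
#       "text": ["<prompt 1>", "<prompt 2>"],
#       "sampling_params": {"temperature": 0.0, "max_new_tokens": 1024, "stop": ["<stop-seq>"]},
#       "return_logprob": true,
#       "top_logprobs_num": 2
#     }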
def load_dataset_items(config: Config, split: str) -> list[dict | object]:
"""Loads dataset items using flashrag's get_dataset."""
logger.info(f"📚 Loading dataset: {config.dataset_name}, Split: {split}")
try:
all_splits = get_dataset(config)
if split not in all_splits:
logger.error(
f"Split '{split}' not found in dataset '{config.dataset_name}'. Available splits: {list(all_splits.keys())}"
)
return []
dataset_items = all_splits[split]
logger.info(f"Successfully loaded {len(dataset_items)} items for split '{split}'.")
return dataset_items
except FileNotFoundError:
logger.error(
f"Dataset files not found for '{config.dataset_name}' in '{config.data_dir}'. Check config and paths."
)
return []
except Exception as e:
logger.error(f"Error loading dataset using get_dataset: {e}", exc_info=True)
return []
def save_results(args: argparse.Namespace, config: Config, result_dataset, run_duration: float):
"""Saves summary and debug information."""
logger.info("💾 Saving results...")
summary_file = os.path.join(args.save_dir, f"{args.save_note}_summary.txt")
debug_file = os.path.join(args.save_dir, f"{args.save_note}_debug.json")
num_items = len(result_dataset)
logger.info(f"Saving summary results to {summary_file}...")
try:
with open(summary_file, "w", encoding="utf-8") as f:
f.write("EVALUATION SUMMARY\n")
f.write("=================\n\n")
f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"Run Duration: {run_duration:.2f} seconds\n")
f.write(f"Dataset: {config.dataset_name} ({args.split} split)\n")
f.write(f"Model: {config.generator_model_path}\n")
f.write(f"Retriever: {config.remote_retriever_url}\n")
f.write(f"Agent Max Generations: {getattr(config, 'agent_max_generations', 'N/A')}\n")
f.write(f"Generator Max Output Len: {getattr(config, 'generator_max_output_len', 'N/A')}\n\n")
f.write(f"Total items processed: {num_items}\n")
f.write("\nNote: Verification was skipped in this run.\n")
f.write("Note: Overall metrics (like EM, F1) are usually printed to console by evaluate method.\n")
logger.info(f"✅ Summary saved to {summary_file}")
except Exception as e:
logger.error(f"Error saving summary file '{summary_file}': {e}", exc_info=True)
logger.info(f"Saving debug information (predictions & responses) to {debug_file}...")
try:
debug_data = []
for i, item in enumerate(result_dataset):
item_data: dict[str, object] = {}
def get_item_value(data_item, key_or_attr: str) -> str | int | float | list | bool | None:
if isinstance(data_item, dict):
return data_item.get(key_or_attr)
elif hasattr(data_item, key_or_attr):
return getattr(data_item, key_or_attr)
return None
item_data["item_index"] = i
item_data["question"] = get_item_value(item, "question")
item_data["prediction"] = get_item_value(item, "pred")
item_data["final_response"] = get_item_value(item, "final_response")
gt_answer_val = None
try:
gt_answer_val = get_item_value(item, "answer")
if gt_answer_val is None:
answers_list = get_item_value(item, "answers")
if isinstance(answers_list, list) and answers_list:
raw_ans = answers_list[0]
if isinstance(raw_ans, (str, int, float, bool)):
gt_answer_val = raw_ans
else:
gt_answer_val = str(raw_ans)
elif not isinstance(gt_answer_val, (str, int, float, bool)):
gt_answer_val = str(gt_answer_val)
except Exception as e:
logger.warning(f"Could not safely get ground truth for item {i}: {e}")
gt_answer_val = "ERROR_GETTING_ANSWER"
item_data["ground_truth"] = gt_answer_val
eval_score_val = None
try:
eval_score_val = get_item_value(item, "score")
if not isinstance(eval_score_val, (str, int, float, bool, type(None))):
eval_score_val = str(eval_score_val)
except Exception as e:
logger.warning(f"Could not safely get score for item {i}: {e}")
eval_score_val = "ERROR_GETTING_SCORE"
item_data["eval_score"] = eval_score_val
debug_data.append(item_data)
with open(debug_file, "w", encoding="utf-8") as f:
json.dump(debug_data, f, indent=2, ensure_ascii=False)
logger.info(f"✅ Debug information saved to {debug_file}")
except Exception as e:
logger.error(f"Error saving debug file '{debug_file}': {e}", exc_info=True)
def research(args: argparse.Namespace, config: Config):
"""Main function to run the research evaluation pipeline."""
logger.info("🚀 Starting research pipeline execution...")
start_time = time.time()
test_data = load_dataset_items(config, args.split)
if not test_data:
logger.error("Failed to load test data. Exiting.")
return
try:
logger.info("🏗️ Building ReSearchPipeline...")
pipeline = ReSearchPipeline(config)
logger.info("✅ Pipeline built successfully.")
except Exception as e:
logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
return
try:
logger.info("🏃 Starting pipeline run...")
result_dataset = pipeline.run(test_data, do_eval=True)
logger.info("✅ Pipeline run completed.")
except Exception as e:
logger.error(f"Error during pipeline run: {e}", exc_info=True)
result_dataset = test_data
logger.warning("Pipeline run failed, attempting to save inputs/partial results.")
run_duration = time.time() - start_time
logger.info(f"Total run duration: {run_duration:.2f} seconds.")
save_results(args, config, result_dataset, run_duration)
logger.info("🏁 Research pipeline execution finished.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Running ReSearch Evaluation Pipeline")
parser.add_argument(
"--config_path", type=str, default="./eval_config.yaml", help="Path to the main FlashRAG config file."
)
parser.add_argument(
"--dataset_name",
type=str,
default="bamboogle",
help="Name of the dataset (must match config or data_dir structure).",
)
parser.add_argument(
"--split", type=str, default="test", help="Dataset split to evaluate (e.g., test, validation)."
)
parser.add_argument("--save_dir", type=str, default="./output_logs", help="Directory to save logs and results.")
parser.add_argument("--save_note", type=str, default="research_run", help="A note to prepend to saved filenames.")
parser.add_argument("--data_dir", type=str, help="Override data directory specified in config.")
parser.add_argument(
"--sgl_remote_url", type=str, help="Override SGLang remote generator URL (e.g., localhost:8002)."
)
parser.add_argument(
"--remote_retriever_url", type=str, help="Override remote retriever URL (e.g., localhost:8001)."
)
parser.add_argument("--generator_model_path", type=str, help="Override generator model path specified in config.")
parser.add_argument("--retriever_topk", type=int, help="Override retriever top K.")
parser.add_argument("--generator_max_output_len", type=int, help="Override generator max output length.")
parser.add_argument("--agent_max_generations", type=int, help="Override agent max interaction turns.")
args = parser.parse_args()
logger.info(f"Starting evaluation script with arguments: {args}")
try:
os.makedirs(args.save_dir, exist_ok=True)
logger.info(f"💾 Logs and results will be saved to: {args.save_dir}")
except OSError as e:
logger.error(f"Could not create save directory '{args.save_dir}': {e}", exc_info=True)
exit(1)
config_overrides = {
k: v
for k, v in vars(args).items()
if v is not None
and k
not in [
"config_path",
"dataset_name",
"split",
"save_dir",
"save_note",
]
}
logger.info(f"🔧 Loading configuration from: {args.config_path}")
try:
config = Config(args.config_path, config_dict=config_overrides)
config.dataset_name = args.dataset_name
if args.data_dir:
config.data_dir = args.data_dir
logger.info(f"Effective data_dir: {getattr(config, 'data_dir', 'N/A')}")
logger.info(f"Effective generator_model_path: {getattr(config, 'generator_model_path', 'N/A')}")
logger.info(f"Effective sgl_remote_url: {getattr(config, 'sgl_remote_url', 'N/A')}")
logger.info(f"Effective remote_retriever_url: {getattr(config, 'remote_retriever_url', 'N/A')}")
logger.info("✅ Config loaded and potentially overridden by CLI arguments.")
config["dataset_path"] = os.path.join(config.data_dir, config.dataset_name)
except FileNotFoundError:
logger.error(f"Config file not found at '{args.config_path}'. Please check the path.")
exit(1)
except Exception as e:
logger.error(f"Error loading or processing configuration: {e}", exc_info=True)
exit(1)
research(args, config)
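The override handling above means any generator/retriever argument passed on the command line takes precedence over the YAML config, while config_path, dataset_name, split, save_dir, and save_note are handled separately. A standalone sketch of the filtering (hypothetical values, mirroring the dict comprehension above):

from argparse import Namespace

args = Namespace(
    config_path="./eval_config.yaml", dataset_name="bamboogle", split="test",
    save_dir="./output_logs", save_note="research_run", data_dir=None,
    sgl_remote_url="localhost:8002", remote_retriever_url=None,
    generator_model_path=None, retriever_topk=3,
    generator_max_output_len=None, agent_max_generations=None,
)
excluded = ["config_path", "dataset_name", "split", "save_dir", "save_note"]
overrides = {k: v for k, v in vars(args).items() if v is not None and k not in excluded}
assert overrides == {"sgl_remote_url": "localhost:8002", "retriever_topk": 3}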

@@ -0,0 +1,68 @@
import argparse
import os
import zipfile
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from config import DATA_DIR
def parse_args() -> argparse.Namespace:
"""Parse command line arguments.
Returns:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(description="Download FlashRAG datasets from HuggingFace Hub")
parser.add_argument(
"--repo-id",
type=str,
default="RUC-NLPIR/FlashRAG_datasets",
help="HuggingFace repository IDs",
)
parser.add_argument(
"--local-dir",
type=str,
default=DATA_DIR / "flashrag_datasets",
help="Local directory to save model",
)
return parser.parse_args()
def main():
"""Main function to download model."""
args = parse_args()
load_dotenv(override=True)
# Configuration
HF_TOKEN = os.getenv("HF_TOKEN")
ALLOW_PATTERNS = [
"*retrieval-corpus*",
"*bamboogle*",
"*nq*",
]
# Download the dataset
snapshot_download(
token=HF_TOKEN,
repo_id=args.repo_id,
local_dir=args.local_dir,
repo_type="dataset",
# ignore_patterns=IGNORE_PATTERNS,
allow_patterns=ALLOW_PATTERNS,
)
# unzip data/flashrag_datasets/retrieval-corpus/wiki18_100w.zip
print("Unzipping wiki18_100w.zip. Might take a while...")
zip_file_path = os.path.join(args.local_dir, "retrieval-corpus", "wiki18_100w.zip")
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(args.local_dir)
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
if __name__ == "__main__":
main()

@@ -0,0 +1,43 @@
"""Download and extract FlashRAG index."""
import os
import zipfile
import requests
from tqdm import tqdm
from config import DATA_DIR
# Constants
URL = "https://www.modelscope.cn/datasets/hhjinjiajie/FlashRAG_Dataset/resolve/master/retrieval_corpus/wiki18_100w_e5_index.zip"
ZIP_NAME = "wiki18_100w_e5_index.zip"
zip_path = DATA_DIR / ZIP_NAME
# Download with progress bar
print("📥 Downloading index...")
response = requests.get(URL, stream=True)
total_size = int(response.headers.get("content-length", 0))
with (
open(zip_path, "wb") as f,
tqdm(
desc=ZIP_NAME,
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar,
):
for data in response.iter_content(chunk_size=1024):
size = f.write(data)
bar.update(size)
# Extract
print("📦 Extracting index...")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(DATA_DIR)
# Clean up zip
os.remove(zip_path)
print("✅ Download and extraction completed successfully!")
print(f"Index file is at: {DATA_DIR}/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index")

@@ -0,0 +1,54 @@
import argparse
import os
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID
def parse_args() -> argparse.Namespace:
"""Parse command line arguments.
Returns:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
parser.add_argument(
"--repo-id",
type=str,
default=GENERATOR_MODEL_REPO_ID,
help="HuggingFace repository ID",
)
parser.add_argument(
"--local-dir",
type=str,
default=GENERATOR_MODEL_DIR,
help="Local directory to save model",
)
return parser.parse_args()
def main():
"""Main function to download model."""
args = parse_args()
load_dotenv(override=True)
# Configuration
HF_TOKEN = os.getenv("HF_TOKEN")
print("Downloading model to", args.local_dir)
# Download the model
snapshot_download(
token=HF_TOKEN,
repo_id=args.repo_id,
local_dir=args.local_dir,
repo_type="model",
)
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
if __name__ == "__main__":
main()

@@ -0,0 +1,54 @@
import argparse
import os
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from config import RETRIEVER_MODEL_DIR, RETRIEVER_MODEL_REPO_ID
def parse_args() -> argparse.Namespace:
"""Parse command line arguments.
Returns:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
parser.add_argument(
"--repo-id",
type=str,
default=RETRIEVER_MODEL_REPO_ID,
help="HuggingFace repository ID",
)
parser.add_argument(
"--local-dir",
type=str,
default=RETRIEVER_MODEL_DIR,
help="Local directory to save model",
)
return parser.parse_args()
def main():
"""Main function to download model."""
args = parse_args()
load_dotenv(override=True)
# Configuration
HF_TOKEN = os.getenv("HF_TOKEN")
print("Downloading model to", args.local_dir)
# Download the model
snapshot_download(
token=HF_TOKEN,
repo_id=args.repo_id,
local_dir=args.local_dir,
repo_type="model",
)
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
if __name__ == "__main__":
main()

@@ -0,0 +1,9 @@
# ------------------------------------------------Environment Settings------------------------------------------------#
gpu_id: "0"
# -------------------------------------------------Retrieval Settings------------------------------------------------#
# If a name is set, the model path will be looked up in the global paths
retrieval_method: "e5" # name or path of the retrieval model.
index_path: "/mnt/nas/thinhlpg/data/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index" # path to the indexed file
faiss_gpu: False # whether to use GPU to hold the index
corpus_path: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/wiki18_100w.jsonl" # path to the corpus in '.jsonl' format that stores the documents
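This config is consumed by the retriever server below through flashrag. A minimal offline sketch of the same loading path (assuming flashrag is installed and the index/corpus paths above exist):

from flashrag.config import Config
from flashrag.utils import get_retriever

config = Config("retriever_config.yaml")
retriever = get_retriever(config)
# search(query, top_n, return_score) mirrors the call made by the server below
docs = retriever.search("who wrote The Iliad", 5, False)
print(docs[0]["contents"][:200])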

@@ -0,0 +1,125 @@
import subprocess
import sys
from pathlib import Path
# Add project root to sys.path to allow importing config
# Assuming the script is at DeepSearch/scripts/serving/serve_generator.py
# The project root (DeepSearch) is parents[2]
PROJ_ROOT = Path(__file__).resolve().parents[2]
if str(PROJ_ROOT) not in sys.path:
sys.path.append(str(PROJ_ROOT))
# Import after adjusting sys.path
try:
from config import (
GENERATOR_MODEL_REPO_ID,
GENERATOR_SERVER_PORT,
MODEL_CONFIG,
logger,
)
except ImportError as e:
# Use print here as logger might not be available if import failed
print(
f"Error importing config: {e}. Make sure config.py is in the project root ({PROJ_ROOT}) and added to sys.path."
)
sys.exit(1)
def launch_sglang_server(
model_id: str,
port: int,
context_length: int,
host: str = "0.0.0.0",
dtype: str = "bfloat16",
) -> None:
"""Launches the SGLang server using specified configurations.
Args:
model_id: The Hugging Face repository ID of the model.
port: The port number for the server.
context_length: The maximum context length for the model.
host: The host address for the server.
dtype: The data type for the model (e.g., 'bfloat16', 'float16').
"""
command = [
sys.executable, # Use the current Python interpreter
"-m",
"sglang.launch_server",
"--model-path",
model_id,
"--context-length",
str(context_length),
"--enable-metrics",
"--dtype",
dtype,
"--host",
host,
"--port",
str(port),
"--trust-remote-code",
# Disable overlap scheduling; recommended by SGLang for stability in some setups
"--disable-overlap",
# Disable the radix cache, which can occasionally cause issues
"--disable-radix-cache",
]
# Log the command clearly
command_str = " ".join(command)
logger.info(f"🚀 Launching SGLang server with command: {command_str}")
process = None # Initialize process to None
try:
# Use Popen to start the server process
# It runs in the foreground relative to this script,
# but allows us to catch KeyboardInterrupt cleanly.
process = subprocess.Popen(command)
# Wait for the process to complete (e.g., user interruption)
process.wait()
# Check return code after waiting
if process.returncode != 0:
logger.error(f"💥 SGLang server process exited with error code: {process.returncode}")
sys.exit(process.returncode)
else:
logger.info("✅ SGLang server process finished gracefully.")
except FileNotFoundError:
logger.error(f"💥 Error: Python executable or sglang module not found.")
logger.error(f"Ensure '{sys.executable}' is correct and sglang is installed.")
sys.exit(1)
except KeyboardInterrupt:
logger.info("🛑 SGLang server launch interrupted by user. Stopping server...")
# Attempt to terminate the process gracefully
if process and process.poll() is None: # Check if process exists and is running
process.terminate()
try:
process.wait(timeout=5) # Wait a bit for termination
logger.info("✅ Server terminated gracefully.")
except subprocess.TimeoutExpired:
logger.warning("⚠️ Server did not terminate gracefully, forcing kill.")
process.kill()
sys.exit(0) # Exit cleanly after interrupt
except Exception as e:
# Catch any other unexpected exceptions during launch or waiting
logger.error(f"🚨 An unexpected error occurred: {e}")
# Ensure process is cleaned up if it exists
if process and process.poll() is None:
process.kill()
sys.exit(1)
if __name__ == "__main__":
# Get context length from config, default to 8192
context_len = MODEL_CONFIG.get("max_seq_length", 8192)
logger.info("----------------------------------------------------")
logger.info("✨ Starting SGLang Generator Server ✨")
logger.info(f" Model ID: {GENERATOR_MODEL_REPO_ID}")
logger.info(f" Port: {GENERATOR_SERVER_PORT}")
logger.info(f" Context Length: {context_len}")
logger.info("----------------------------------------------------")
launch_sglang_server(
model_id=GENERATOR_MODEL_REPO_ID,
port=GENERATOR_SERVER_PORT,
context_length=context_len,
)
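Once the server is up, a quick smoke test against the /health and /generate endpoints (the request shape matches what SGLRemoteGenerator sends above; localhost:8002 assumes the default GENERATOR_SERVER_PORT):

import requests

print(requests.get("http://localhost:8002/health").status_code)  # expect 200
resp = requests.post(
    "http://localhost:8002/generate",
    json={"text": ["Hello, world"], "sampling_params": {"temperature": 0.0, "max_new_tokens": 16}},
    timeout=120,
)
print(resp.json())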

@@ -0,0 +1,113 @@
import argparse
import asyncio
from collections import deque
from typing import List, Tuple, Union
from fastapi import FastAPI, HTTPException
from flashrag.config import Config
from flashrag.utils import get_retriever
from pydantic import BaseModel
from config import RETRIEVER_SERVER_PORT
app = FastAPI()
retriever_list = []
available_retrievers = deque()
retriever_semaphore = None
def init_retriever(args):
global retriever_semaphore
config = Config(args.config)
for i in range(args.num_retriever):
print(f"Initializing retriever {i + 1}/{args.num_retriever}")
retriever = get_retriever(config)
retriever_list.append(retriever)
available_retrievers.append(i)
# create a semaphore to limit the number of retrievers that can be used concurrently
retriever_semaphore = asyncio.Semaphore(args.num_retriever)
@app.get("/health")
async def health_check():
return {"status": "healthy", "retrievers": {"total": len(retriever_list), "available": len(available_retrievers)}}
class QueryRequest(BaseModel):
query: str
top_n: int = 10
return_score: bool = False
class BatchQueryRequest(BaseModel):
query: List[str]
top_n: int = 10
return_score: bool = False
class Document(BaseModel):
id: str
contents: str
@app.post("/search", response_model=Union[Tuple[List[Document], List[float]], List[Document]])
async def search(request: QueryRequest):
query = request.query
top_n = request.top_n
return_score = request.return_score
if not query or not query.strip():
print(f"Query content cannot be empty: {query}")
raise HTTPException(status_code=400, detail="Query content cannot be empty")
async with retriever_semaphore:
retriever_idx = available_retrievers.popleft()
try:
if return_score:
results, scores = retriever_list[retriever_idx].search(query, top_n, return_score)
return [Document(id=result["id"], contents=result["contents"]) for result in results], scores
else:
results = retriever_list[retriever_idx].search(query, top_n, return_score)
return [Document(id=result["id"], contents=result["contents"]) for result in results]
finally:
available_retrievers.append(retriever_idx)
@app.post("/batch_search", response_model=Union[List[List[Document]], Tuple[List[List[Document]], List[List[float]]]])
async def batch_search(request: BatchQueryRequest):
query = request.query
top_n = request.top_n
return_score = request.return_score
async with retriever_semaphore:
retriever_idx = available_retrievers.popleft()
try:
if return_score:
results, scores = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
return [
[Document(id=result["id"], contents=result["contents"]) for result in results[i]]
for i in range(len(results))
], scores
else:
results = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
return [
[Document(id=result["id"], contents=result["contents"]) for result in results[i]]
for i in range(len(results))
]
finally:
available_retrievers.append(retriever_idx)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./retriever_config.yaml")
parser.add_argument("--num_retriever", type=int, default=1)
parser.add_argument("--port", type=int, default=RETRIEVER_SERVER_PORT)
args = parser.parse_args()
init_retriever(args)
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=args.port)
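A minimal client for the two endpoints above (payloads match RemoteRetriever in the evaluation script; localhost:8001 assumes the default RETRIEVER_SERVER_PORT):

import requests

single = requests.post(
    "http://localhost:8001/search",
    json={"query": "capital of France", "top_n": 3, "return_score": False},
    timeout=30,
).json()
batch = requests.post(
    "http://localhost:8001/batch_search",
    json={"query": ["capital of France", "author of Hamlet"], "top_n": 3, "return_score": False},
    timeout=60,
).json()
print(single[0]["contents"][:100])
print(len(batch), "result lists")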