From 14ef79a4f532e1563a1d3430d69406c23e8b506b Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Thu, 10 Apr 2025 06:48:14 +0000
Subject: [PATCH] feat: [WIP] add bench scripts

---
 .gitignore                                    |   1 +
 pyproject.toml                                |   1 +
 scripts/evaluation/eval_config.yaml           |  46 ++
 scripts/evaluation/run_eval.py                | 659 ++++++++++++++++++
 scripts/serving/download_flashrag_datasets.py |  68 ++
 scripts/serving/download_flashrag_index.py    |  43 ++
 scripts/serving/download_generator_model.py   |  54 ++
 scripts/serving/download_retriever_model.py   |  54 ++
 scripts/serving/retriever_config.yaml         |   9 +
 scripts/serving/serve_generator.py            | 125 ++++
 scripts/serving/serve_retriever.py            | 113 +++
 11 files changed, 1173 insertions(+)
 create mode 100644 scripts/evaluation/eval_config.yaml
 create mode 100644 scripts/evaluation/run_eval.py
 create mode 100644 scripts/serving/download_flashrag_datasets.py
 create mode 100644 scripts/serving/download_flashrag_index.py
 create mode 100644 scripts/serving/download_generator_model.py
 create mode 100644 scripts/serving/download_retriever_model.py
 create mode 100644 scripts/serving/retriever_config.yaml
 create mode 100644 scripts/serving/serve_generator.py
 create mode 100644 scripts/serving/serve_retriever.py

diff --git a/.gitignore b/.gitignore
index 9348aa4..1a2d02e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,6 +18,7 @@ logs/
 *.code-workspace
 data/
 .gradio/
+output_*

 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/pyproject.toml b/pyproject.toml
index d4946c9..bc7d10c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,4 +38,5 @@ dependencies = [
     "requests>=2.31.0",
     "tqdm>=4.66.1",
     "tavily-python",
+    "sglang[all]>=0.4.5",
 ]
\ No newline at end of file
diff --git a/scripts/evaluation/eval_config.yaml b/scripts/evaluation/eval_config.yaml
new file mode 100644
index 0000000..ad503d6
--- /dev/null
+++ b/scripts/evaluation/eval_config.yaml
@@ -0,0 +1,46 @@
+# ------------------------------------------------Environment Settings------------------------------------------------#
+# Directory paths for data and outputs
+data_dir: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/"
+save_dir: "/mnt/nas/thinhlpg/code/DeepSearch/logs"
+
+# Seed for reproducibility
+seed: 2024
+
+# Whether to save intermediate data
+save_intermediate_data: True
+save_note: 'experiment'
+
+# -------------------------------------------------Retrieval Settings------------------------------------------------#
+# If the remote URL is set, the retriever will be a remote retriever and the following settings are ignored
+use_remote_retriever: True
+remote_retriever_url: "localhost:8001"
+
+instruction: ~ # instruction for the retrieval model
+retrieval_topk: 5 # number of retrieved documents
+retrieval_batch_size: 256 # batch size for retrieval
+retrieval_use_fp16: True # whether to use fp16 for the retrieval model
+retrieval_query_max_length: 128 # max length of the query
+save_retrieval_cache: False # whether to save the retrieval cache
+use_retrieval_cache: False # whether to use the retrieval cache
+retrieval_cache_path: ~ # path to the retrieval cache
+retrieval_pooling_method: ~ # set automatically if not provided
+
+# -------------------------------------------------Generator Settings------------------------------------------------#
+framework: sgl_remote # inference framework of the LLM, supporting: 'hf', 'vllm', 'fschat'
+sgl_remote_url: "localhost:8002"
+generator_model: "janhq/250404-llama-3.2-3b-instruct-grpo-03-s250" # name or path of the generator model, for loading the tokenizer
+generator_max_input_len: 2048 # max length of the input
+generation_params:
+  do_sample: False
+  max_tokens: 8192
+
+# -------------------------------------------------Evaluation Settings------------------------------------------------#
+# Metrics to evaluate the result
+metrics: ['em', 'f1', 'acc', 'precision', 'recall']
+# Specify settings for metrics; used within certain metrics
+metric_setting:
+  retrieval_recall_topk: 5
+save_metric_score: True # whether to save the metric score into a txt file
+
+
+
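Usage sketch (illustrative, not part of the diff): run_eval.py below loads this file through flashrag.config.Config, with CLI flags passed in as overrides. Assuming flashrag is installed and this file sits in the working directory as eval_config.yaml, the override behavior looks like:

    from flashrag.config import Config

    # Keys in config_dict win over YAML values (mirrors run_eval.py's __main__).
    overrides = {"retrieval_topk": 3, "sgl_remote_url": "localhost:9002"}
    config = Config("eval_config.yaml", config_dict=overrides)
    print(config["retrieval_topk"])   # 3, from the override dict
    print(config["generator_model"])  # value taken from the YAML above
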
diff --git a/scripts/evaluation/run_eval.py b/scripts/evaluation/run_eval.py
new file mode 100644
index 0000000..a0ccb16
--- /dev/null
+++ b/scripts/evaluation/run_eval.py
@@ -0,0 +1,659 @@
+import argparse
+import json
+import os
+import re
+import time
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from copy import deepcopy
+from datetime import datetime
+from functools import wraps
+
+import numpy as np
+import requests
+from flashrag.config import Config
+from flashrag.generator.generator import BaseGenerator
+from flashrag.pipeline import BasicPipeline
+from flashrag.retriever.retriever import BaseTextRetriever
+from flashrag.utils import get_dataset
+from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from config import logger
+from src.agent import Agent, AgenticOutputs
+from src.prompts import build_user_prompt, get_system_prompt
+from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
+
+
+def retry(max_retries=10, sleep=1):
+    """Decorator to retry a function with exponential backoff."""
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            func_name = func.__name__
+            for attempt in range(max_retries):
+                try:
+                    result = func(*args, **kwargs)
+                    return result
+                except Exception as e:
+                    logger.warning(f"Attempt {attempt + 1} of {func_name} failed: {e}")
+                    if attempt == max_retries - 1:
+                        logger.error(f"Function {func_name} failed after {max_retries} retries.", exc_info=True)
+                        raise e
+                    backoff_time = sleep * (2**attempt)
+                    logger.info(f"Retrying {func_name} in {backoff_time:.2f} seconds...")
+                    time.sleep(backoff_time)
+            logger.error(f"Function {func_name} retry logic finished unexpectedly.")
+            return None
+
+        return wrapper
+
+    return decorator
+
+
+class RemoteRetriever(BaseTextRetriever):
+    """A wrapper for a remote retriever service, with retry logic and logging."""
+
+    def __init__(self, config: Config):
+        """Initializes the RemoteRetriever."""
+        super().__init__(config)
+        self.remote_url = f"http://{getattr(config, 'remote_retriever_url', 'localhost:8001')}"
+        self.topk = getattr(config, "retriever_topk", 5)
+        logger.info(f"πŸ”— Remote retriever URL: {self.remote_url}")
+
+    @retry(max_retries=3, sleep=2)
+    def _search(self, query: str, num: int | None = None, return_score: bool = False) -> list[dict]:
+        """Search for documents using the remote retriever service."""
+        num = num if num is not None else self.topk
+        url = f"{self.remote_url}/search"
+
+        try:
+            response = requests.post(
+                url,
+                json={"query": query, "top_n": num, "return_score": return_score},
+                timeout=30,
+            )
+            response.raise_for_status()
+
+            results = response.json()
+            return results
+        except requests.exceptions.Timeout:
+            logger.error(f"Search request timed out after 30 seconds for query: {query[:50]}...")
+            raise
+        except requests.exceptions.ConnectionError:
+            logger.error(f"Could not connect to search service at {url}")
+            raise
+        except requests.exceptions.RequestException as e:
+            logger.error(f"Search request failed: {e}", exc_info=True)
+            raise
+        except Exception
as e: + logger.error(f"Unexpected search error: {str(e)}", exc_info=True) + raise + + @retry(max_retries=3, sleep=2) + def _batch_search( + self, queries: list[str], num: int | None = None, return_score: bool = False + ) -> list[list[dict]]: + """Batch search for documents using the remote retriever service.""" + num = num if num is not None else self.topk + url = f"{self.remote_url}/batch_search" + + try: + response = requests.post( + url, + json={"query": queries, "top_n": num, "return_score": return_score}, + timeout=60, + ) + response.raise_for_status() + results = response.json() + return results + except requests.exceptions.Timeout: + logger.error(f"Batch search request timed out after 60 seconds for {len(queries)} queries.") + raise + except requests.exceptions.ConnectionError: + logger.error(f"Could not connect to batch search service at {url}") + raise + except requests.exceptions.RequestException as e: + logger.error(f"Batch search request failed: {e}", exc_info=True) + raise + except Exception as e: + logger.error(f"Unexpected batch search error: {str(e)}", exc_info=True) + raise + + +class ReSearchPipeline(BasicPipeline): + """Pipeline for ReSearch method using Agent for generation and tool use.""" + + def __init__( + self, config: Config, retriever: BaseTextRetriever | None = None, generator: BaseGenerator | None = None + ): + """Initializes the ReSearchPipeline.""" + super().__init__(config) + logger.info("πŸ”§ Initializing ReSearchPipeline...") + + self.retriever = retriever or RemoteRetriever(config) + + self.generator = generator or SGLRemoteGenerator(config) + + try: + self.tokenizer = AutoTokenizer.from_pretrained(config.generator_model_path, trust_remote_code=True) + if not self.tokenizer.pad_token: + logger.warning("Tokenizer does not have a pad token; setting to eos_token.") + self.tokenizer.pad_token = self.tokenizer.eos_token + self.tokenizer.padding_side = "left" + logger.info("βœ… Tokenizer initialized.") + except Exception as e: + logger.error(f"Failed to initialize tokenizer: {e}", exc_info=True) + raise + + tokenizer_name = self.tokenizer.name_or_path.lower() + + if "deepseek-ai/deepseek-r1-distill" in tokenizer_name: + adapter = R1DistilTokenizerAdapter() + elif "llama" in tokenizer_name: + adapter = LlamaTokenizerAdapter() + else: + logger.warning(f"Unknown tokenizer type '{tokenizer_name}', defaulting to R1DistilTokenizerAdapter.") + adapter = R1DistilTokenizerAdapter() + logger.info(f"πŸ”© Using Tokenizer Adapter: {type(adapter).__name__}") + + def retriever_search(query: str, return_type=str, results: int = 2): + try: + search_results = self.retriever._search(query, num=results) + return self.format_search_results(search_results) + except Exception as e: + logger.error(f"Error during agent's retriever search for query '{query[:50]}...': {e}", exc_info=True) + return "Search failed due to an internal error." + + self.agent = Agent(adapter, search_fn=retriever_search) + logger.info("βœ… Agent initialized.") + logger.info("βœ… ReSearchPipeline initialized successfully.") + + def format_search_results(self, search_results: list[dict]) -> str: + """Formats search results into a string for the agent prompt.""" + if not search_results: + return "No results found." + max_content_len = 500 + formatted = "\n-------\n".join( + [ + f"Result {i + 1}: {r.get('contents', 'N/A')[:max_content_len]}{'...' 
if len(r.get('contents', '')) > max_content_len else ''}"
+                for i, r in enumerate(search_results)
+            ]
+        )
+        formatted_str = f"{formatted}"
+
+        return formatted_str
+
+    def extract_search_query(self, text: str) -> str | None:
+        """Extract the search query from text between <search> tags."""
+        pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
+        matches = pattern.findall(text)
+        if matches:
+            query = matches[-1].strip()
+            return query
+        return None
+
+    def extract_answer(self, text: str) -> str | None:
+        """Extract the answer from text between <answer> tags."""
+        pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
+        matches = pattern.findall(text)
+        if matches:
+            answer = matches[-1].strip()
+
+            return answer
+
+        return None
+
+    def run(self, dataset, do_eval: bool = True, pred_process_fun=None):
+        """Runs the ReSearch pipeline on the dataset using the Agent."""
+        logger.info(f"πŸƒ Starting ReSearch pipeline run with {len(dataset)} items...")
+
+        try:
+            questions = [item.question if hasattr(item, "question") else item["question"] for item in dataset]
+
+        except (KeyError, AttributeError, TypeError) as e:
+            logger.error(f"Failed to extract questions from dataset items. Error: {e}", exc_info=True)
+            logger.error("Ensure dataset items have a 'question' key or attribute.")
+            return dataset
+
+        agent_max_generations = getattr(self.config, "agent_max_generations", 20)
+        generator_max_output_len = getattr(self.config, "generator_max_output_len", 1024)
+
+        try:
+            logger.info(f"πŸ€– Running agent inference for {len(questions)} questions...")
+            agent_outputs: AgenticOutputs = self.agent.run_agent(
+                generate_fn=self.generator.generate,
+                tokenizer=self.tokenizer,
+                questions=questions,
+                max_generations=agent_max_generations,
+                max_new_tokens=generator_max_output_len,
+            )
+            final_responses = agent_outputs.final_response_str
+            logger.info(f"βœ… Agent inference completed. Received {len(final_responses)} final responses.")
+
+        except Exception as e:
+            logger.error(f"Agent run failed during inference: {e}", exc_info=True)
+            logger.warning("Agent run failed, attempting evaluation with potentially incomplete results.")
+            for item in dataset:
+                if hasattr(item, "update_output"):
+                    item.update_output("pred", "AGENT_ERROR")
+                elif isinstance(item, dict):
+                    item["pred"] = "AGENT_ERROR"
+
+        logger.info("πŸ“ Extracting answers and updating dataset items...")
+        num_updated = 0
+        num_missing_answers = 0
+        if len(final_responses) == len(dataset):
+            for i, item in enumerate(dataset):
+                response = final_responses[i]
+                answer = self.extract_answer(response)
+                pred_to_save = answer if answer is not None else ""
+
+                if answer is None:
+                    num_missing_answers += 1
+
+                if hasattr(item, "update_output"):
+                    item.update_output("pred", pred_to_save)
+                    item.update_output("final_response", response)
+                    num_updated += 1
+                elif isinstance(item, dict):
+                    item["pred"] = pred_to_save
+                    item["final_response"] = response
+                    num_updated += 1
+                else:
+                    logger.warning(f"Item {i} has unknown type {type(item)}, cannot update with prediction.")
+
+            logger.info(f"Updated {num_updated}/{len(dataset)} dataset items with predictions.")
+            if num_missing_answers > 0:
+                logger.warning(f"{num_missing_answers} items had no <answer> tag.")
+        else:
+            logger.error(
+                f"Mismatch between dataset size ({len(dataset)}) and number of agent responses ({len(final_responses)}). Cannot reliably update dataset."
+ ) + for item in dataset: + if hasattr(item, "update_output"): + item.update_output("pred", "RESPONSE_COUNT_MISMATCH") + elif isinstance(item, dict): + item["pred"] = "RESPONSE_COUNT_MISMATCH" + + if do_eval: + logger.info("πŸ“Š Evaluating results using BasicPipeline.evaluate...") + try: + dataset = self.evaluate(dataset, do_eval=True, pred_process_fun=pred_process_fun) + logger.info("βœ… Evaluation completed via base class method.") + except Exception as e: + logger.error(f"Error during BasicPipeline.evaluate: {e}", exc_info=True) + logger.warning("Evaluation may be incomplete.") + else: + logger.info("Skipping evaluation step as do_eval=False.") + + logger.info("βœ… ReSearch pipeline run finished.") + return dataset + + +class SGLRemoteGenerator(BaseGenerator): + """Class for decoder-only generator, based on SGLang remote service.""" + + def __init__(self, config: Config): + """Initializes the SGLRemoteGenerator.""" + super().__init__(config) + logger.info("πŸ”§ Initializing SGLRemoteGenerator...") + sgl_url = getattr(config, "sgl_remote_url", "localhost:8002") + self.sgl_remote_url = f"http://{sgl_url}/generate" + self.health_check_url = f"http://{sgl_url}/health" + logger.info(f"πŸ”— Remote Generator URL: {self.sgl_remote_url}") + self.model_path = getattr(config, "generator_model_path", None) + if not self.model_path: + logger.error("generator_model_path not found in config!") + raise ValueError("generator_model_path is required for SGLRemoteGenerator") + + try: + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) + logger.info("βœ… Tokenizer loaded for generator.") + except Exception as e: + logger.error(f"Failed to load tokenizer for generator from {self.model_path}: {e}", exc_info=True) + raise + + self.generation_params = getattr(config, "generation_params", {}) + self.config = config + + self._check_health() + + def _check_health(self): + """Checks the health of the remote generator service.""" + try: + test_response = requests.get(self.health_check_url, timeout=5) + test_response.raise_for_status() + logger.info("βœ… Remote generator service is available") + except requests.exceptions.RequestException as e: + logger.error(f"Could not connect or verify remote generator service at {self.health_check_url}: {str(e)}") + logger.warning("Please ensure the SGLang service is running and accessible.") + + @retry(max_retries=5, sleep=2) + def generate( + self, + input_list: list[str] | str, + return_raw_output: bool = False, + return_scores: bool = False, + **params, + ) -> list[str] | tuple[list[str], list[list[float]]] | list[dict]: + """Generates text using the remote SGLang service.""" + if isinstance(input_list, str): + input_list = [input_list] + if not isinstance(input_list, list) or not all(isinstance(item, str) for item in input_list): + raise ValueError("Input must be a string or a list of strings.") + + batch_size = len(input_list) + data_to_remote = {"text": input_list} + + effective_params = deepcopy(self.generation_params) + effective_params.update(params) + + curr_sampling_params = {} + if effective_params.get("do_sample", True) is False: + curr_sampling_params["temperature"] = 0.0 + else: + curr_sampling_params["temperature"] = effective_params.get( + "temperature", getattr(self.config, "temperature", 0.7) + ) + + default_max_tokens = getattr(self.config, "generator_max_output_len", 1024) + curr_sampling_params["max_new_tokens"] = effective_params.get("max_new_tokens", default_max_tokens) + + stop_sequences = 
effective_params.get("stop", []) + if isinstance(stop_sequences, str): + stop_sequences = [stop_sequences] + if stop_sequences: + curr_sampling_params["stop"] = stop_sequences + + keys_to_remove = ["do_sample", "temperature", "max_new_tokens", "stop"] + for key in keys_to_remove: + effective_params.pop(key, None) + + if "top_p" in effective_params: + curr_sampling_params["top_p"] = effective_params["top_p"] + if "top_k" in effective_params: + curr_sampling_params["top_k"] = effective_params["top_k"] + + data_to_remote["sampling_params"] = curr_sampling_params + + if return_scores: + data_to_remote["return_logprob"] = True + data_to_remote["top_logprobs_num"] = getattr(self.config, "top_logprobs_num", 2) + + try: + response = requests.post( + self.sgl_remote_url, json=data_to_remote, timeout=120, headers={"Content-Type": "application/json"} + ) + response.raise_for_status() + + response_list = response.json() + + if return_raw_output: + return response_list + + generated_text = [] + for item in response_list: + text = item.get("text", "") + finish_reason = item.get("meta_info", {}).get("finish_reason", {}) + matched_stop = finish_reason.get("matched") + if matched_stop and curr_sampling_params.get("stop") and matched_stop in curr_sampling_params["stop"]: + text += matched_stop + generated_text.append(text) + + if return_scores: + scores = [] + for resp_item in response_list: + logprobs_list = resp_item.get("meta_info", {}).get("output_token_logprobs", []) + token_scores = [ + np.exp(logprob[0]) if (logprob and len(logprob) > 0) else 0.0 for logprob in logprobs_list + ] + scores.append(token_scores) + return generated_text, scores + else: + return generated_text + + except requests.exceptions.Timeout: + logger.error("Generation request timed out after 120 seconds.") + raise + except requests.exceptions.ConnectionError: + logger.error(f"Could not connect to remote generator service at {self.sgl_remote_url}.") + raise + except requests.exceptions.RequestException as e: + logger.error(f"Network error during generation: {str(e)}", exc_info=True) + raise + except json.JSONDecodeError as e: + response_text = "Unknown (error occurred before response object assignment)" + if "response" in locals() and hasattr(response, "text"): + response_text = response.text[:500] + logger.error( + f"Failed to decode JSON response from {self.sgl_remote_url}. Response text: {response_text}...", + exc_info=True, + ) + raise + except Exception as e: + logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True) + raise + + +def load_dataset_items(config: Config, split: str) -> list[dict | object]: + """Loads dataset items using flashrag's get_dataset.""" + logger.info(f"πŸ“š Loading dataset: {config.dataset_name}, Split: {split}") + try: + all_splits = get_dataset(config) + if split not in all_splits: + logger.error( + f"Split '{split}' not found in dataset '{config.dataset_name}'. Available splits: {list(all_splits.keys())}" + ) + return [] + dataset_items = all_splits[split] + logger.info(f"Successfully loaded {len(dataset_items)} items for split '{split}'.") + + return dataset_items + except FileNotFoundError: + logger.error( + f"Dataset files not found for '{config.dataset_name}' in '{config.data_dir}'. Check config and paths." 
+ ) + return [] + except Exception as e: + logger.error(f"Error loading dataset using get_dataset: {e}", exc_info=True) + return [] + + +def save_results(args: argparse.Namespace, config: Config, result_dataset, run_duration: float): + """Saves summary and debug information.""" + logger.info("πŸ’Ύ Saving results...") + summary_file = os.path.join(args.save_dir, f"{args.save_note}_summary.txt") + debug_file = os.path.join(args.save_dir, f"{args.save_note}_debug.json") + + num_items = len(result_dataset) + + logger.info(f"Saving summary results to {summary_file}...") + try: + with open(summary_file, "w", encoding="utf-8") as f: + f.write("EVALUATION SUMMARY\n") + f.write("=================\n\n") + f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"Run Duration: {run_duration:.2f} seconds\n") + f.write(f"Dataset: {config.dataset_name} ({args.split} split)\n") + f.write(f"Model: {config.generator_model_path}\n") + f.write(f"Retriever: {config.remote_retriever_url}\n") + f.write(f"Agent Max Generations: {getattr(config, 'agent_max_generations', 'N/A')}\n") + f.write(f"Generator Max Output Len: {getattr(config, 'generator_max_output_len', 'N/A')}\n\n") + f.write(f"Total items processed: {num_items}\n") + f.write("\nNote: Verification was skipped in this run.\n") + f.write("Note: Overall metrics (like EM, F1) are usually printed to console by evaluate method.\n") + + logger.info(f"βœ… Summary saved to {summary_file}") + except Exception as e: + logger.error(f"Error saving summary file '{summary_file}': {e}", exc_info=True) + + logger.info(f"Saving debug information (predictions & responses) to {debug_file}...") + try: + debug_data = [] + for i, item in enumerate(result_dataset): + item_data: dict[str, object] = {} + + def get_item_value(data_item, key_or_attr: str) -> str | int | float | list | bool | None: + if isinstance(data_item, dict): + return data_item.get(key_or_attr) + elif hasattr(data_item, key_or_attr): + return getattr(data_item, key_or_attr) + return None + + item_data["item_index"] = i + item_data["question"] = get_item_value(item, "question") + item_data["prediction"] = get_item_value(item, "pred") + item_data["final_response"] = get_item_value(item, "final_response") + + gt_answer_val = None + try: + gt_answer_val = get_item_value(item, "answer") + if gt_answer_val is None: + answers_list = get_item_value(item, "answers") + if isinstance(answers_list, list) and answers_list: + raw_ans = answers_list[0] + if isinstance(raw_ans, (str, int, float, bool)): + gt_answer_val = raw_ans + else: + gt_answer_val = str(raw_ans) + elif not isinstance(gt_answer_val, (str, int, float, bool)): + gt_answer_val = str(gt_answer_val) + except Exception as e: + logger.warning(f"Could not safely get ground truth for item {i}: {e}") + gt_answer_val = "ERROR_GETTING_ANSWER" + item_data["ground_truth"] = gt_answer_val + + eval_score_val = None + try: + eval_score_val = get_item_value(item, "score") + if not isinstance(eval_score_val, (str, int, float, bool, type(None))): + eval_score_val = str(eval_score_val) + except Exception as e: + logger.warning(f"Could not safely get score for item {i}: {e}") + eval_score_val = "ERROR_GETTING_SCORE" + item_data["eval_score"] = eval_score_val + + debug_data.append(item_data) + + with open(debug_file, "w", encoding="utf-8") as f: + json.dump(debug_data, f, indent=2, ensure_ascii=False) + logger.info(f"βœ… Debug information saved to {debug_file}") + except Exception as e: + logger.error(f"Error saving debug file '{debug_file}': {e}", 
exc_info=True) + + +def research(args: argparse.Namespace, config: Config): + """Main function to run the research evaluation pipeline.""" + logger.info("πŸš€ Starting research pipeline execution...") + start_time = time.time() + + test_data = load_dataset_items(config, args.split) + if not test_data: + logger.error("Failed to load test data. Exiting.") + return + + try: + logger.info("πŸ—οΈ Building ReSearchPipeline...") + pipeline = ReSearchPipeline(config) + logger.info("βœ… Pipeline built successfully.") + except Exception as e: + logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True) + return + + agent_max_generations = getattr(config, "agent_max_generations", 20) + generator_max_output_len = getattr(config, "generator_max_output_len", 1024) + + try: + logger.info("πŸƒ Starting pipeline run...") + result_dataset = pipeline.run(test_data, do_eval=True) + logger.info("βœ… Pipeline run completed.") + except Exception as e: + logger.error(f"Error during pipeline run: {e}", exc_info=True) + result_dataset = test_data + logger.warning("Pipeline run failed, attempting to save inputs/partial results.") + + run_duration = time.time() - start_time + logger.info(f"Total run duration: {run_duration:.2f} seconds.") + save_results(args, config, result_dataset, run_duration) + + logger.info("🏁 Research pipeline execution finished.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Running ReSearch Evaluation Pipeline") + parser.add_argument( + "--config_path", type=str, default="./eval_config.yaml", help="Path to the main FlashRAG config file." + ) + parser.add_argument( + "--dataset_name", + type=str, + default="bamboogle", + help="Name of the dataset (must match config or data_dir structure).", + ) + parser.add_argument( + "--split", type=str, default="test", help="Dataset split to evaluate (e.g., test, validation)." + ) + parser.add_argument("--save_dir", type=str, default="./output_logs", help="Directory to save logs and results.") + parser.add_argument("--save_note", type=str, default="research_run", help="A note to prepend to saved filenames.") + + parser.add_argument("--data_dir", type=str, help="Override data directory specified in config.") + parser.add_argument( + "--sgl_remote_url", type=str, help="Override SGLang remote generator URL (e.g., localhost:8002)." + ) + parser.add_argument( + "--remote_retriever_url", type=str, help="Override remote retriever URL (e.g., localhost:8001)." 
+    )
+    parser.add_argument("--generator_model_path", type=str, help="Override generator model path specified in config.")
+    parser.add_argument("--retriever_topk", type=int, help="Override retriever top K.")
+    parser.add_argument("--generator_max_output_len", type=int, help="Override generator max output length.")
+    parser.add_argument("--agent_max_generations", type=int, help="Override agent max interaction turns.")
+
+    args = parser.parse_args()
+    logger.info(f"Starting evaluation script with arguments: {args}")
+
+    try:
+        os.makedirs(args.save_dir, exist_ok=True)
+        logger.info(f"πŸ’Ύ Logs and results will be saved to: {args.save_dir}")
+    except OSError as e:
+        logger.error(f"Could not create save directory '{args.save_dir}': {e}", exc_info=True)
+        exit(1)
+
+    config_overrides = {
+        k: v
+        for k, v in vars(args).items()
+        if v is not None
+        and k
+        not in [
+            "config_path",
+            "dataset_name",
+            "split",
+            "save_dir",
+            "save_note",
+        ]
+    }
+
+    logger.info(f"πŸ”§ Loading configuration from: {args.config_path}")
+    try:
+        config = Config(args.config_path, config_dict=config_overrides)
+        config.dataset_name = args.dataset_name
+        if args.data_dir:
+            config.data_dir = args.data_dir
+
+        logger.info(f"Effective data_dir: {getattr(config, 'data_dir', 'N/A')}")
+        logger.info(f"Effective generator_model_path: {getattr(config, 'generator_model_path', 'N/A')}")
+        logger.info(f"Effective sgl_remote_url: {getattr(config, 'sgl_remote_url', 'N/A')}")
+        logger.info(f"Effective remote_retriever_url: {getattr(config, 'remote_retriever_url', 'N/A')}")
+
+        logger.info("βœ… Config loaded and potentially overridden by CLI arguments.")
+
+        config["dataset_path"] = os.path.join(config.data_dir, config.dataset_name)
+
+    except FileNotFoundError:
+        logger.error(f"Config file not found at '{args.config_path}'. Please check the path.")
+        exit(1)
+    except Exception as e:
+        logger.error(f"Error loading or processing configuration: {e}", exc_info=True)
+        exit(1)
+
+    research(args, config)
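Usage sketch (illustrative, not part of the diff): the pipeline above talks to two HTTP services. A quick way to sanity-check the retriever contract that RemoteRetriever._search expects, assuming serve_retriever.py from this patch is running on localhost:8001:

    import requests

    # Mirrors the payload RemoteRetriever._search sends; "top_n" and "return_score"
    # match the QueryRequest model in serve_retriever.py.
    resp = requests.post(
        "http://localhost:8001/search",
        json={"query": "Who wrote The Selfish Gene?", "top_n": 5, "return_score": False},
        timeout=30,
    )
    resp.raise_for_status()
    for doc in resp.json():
        print(doc["id"], doc["contents"][:80])
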
diff --git a/scripts/serving/download_flashrag_datasets.py b/scripts/serving/download_flashrag_datasets.py
new file mode 100644
index 0000000..53b58f3
--- /dev/null
+++ b/scripts/serving/download_flashrag_datasets.py
@@ -0,0 +1,68 @@
+import argparse
+import os
+import zipfile
+
+from dotenv import load_dotenv
+from huggingface_hub import snapshot_download
+
+from config import DATA_DIR
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse command line arguments.
+
+    Returns:
+        argparse.Namespace: Parsed arguments
+    """
+    parser = argparse.ArgumentParser(description="Download FlashRAG datasets from HuggingFace Hub")
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        default="RUC-NLPIR/FlashRAG_datasets",
+        help="HuggingFace repository ID",
+    )
+    parser.add_argument(
+        "--local-dir",
+        type=str,
+        default=DATA_DIR / "flashrag_datasets",
+        help="Local directory to save the datasets",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """Main function to download the datasets."""
+    args = parse_args()
+    load_dotenv(override=True)
+
+    # Configuration
+    HF_TOKEN = os.getenv("HF_TOKEN")
+
+    ALLOW_PATTERNS = [
+        "*retrieval-corpus*",
+        "*bamboogle*",
+        "*nq*",
+    ]
+
+    # Download the datasets
+    snapshot_download(
+        token=HF_TOKEN,
+        repo_id=args.repo_id,
+        local_dir=args.local_dir,
+        repo_type="dataset",
+        # ignore_patterns=IGNORE_PATTERNS,
+        allow_patterns=ALLOW_PATTERNS,
+    )
+
+    # unzip data/flashrag_datasets/retrieval-corpus/wiki18_100w.zip
+    print("Unzipping wiki18_100w.zip. Might take a while...")
+    zip_file_path = os.path.join(args.local_dir, "retrieval-corpus", "wiki18_100w.zip")
+    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
+        zip_ref.extractall(args.local_dir)
+
+    print(f"βœ… Done: {args.repo_id} -> {args.local_dir}")
+
+
+if __name__ == "__main__":
+    main()
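Usage sketch (illustrative, not part of the diff): after extraction the corpus is a JSONL file with one document per line, and the serving configs point at it via corpus_path. A quick sanity check, with the path assumed from the default --local-dir above:

    import json

    # Each line holds one document with "id" and "contents" fields --
    # the same fields the retriever service returns.
    with open("data/flashrag_datasets/wiki18_100w.jsonl") as f:
        first = json.loads(f.readline())
    print(first["id"], first["contents"][:100])
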
diff --git a/scripts/serving/download_flashrag_index.py b/scripts/serving/download_flashrag_index.py
new file mode 100644
index 0000000..efc1e8e
--- /dev/null
+++ b/scripts/serving/download_flashrag_index.py
@@ -0,0 +1,43 @@
+"""Download and extract FlashRAG index."""
+
+import os
+import zipfile
+
+import requests
+from tqdm import tqdm
+
+from config import DATA_DIR
+
+# Constants
+URL = "https://www.modelscope.cn/datasets/hhjinjiajie/FlashRAG_Dataset/resolve/master/retrieval_corpus/wiki18_100w_e5_index.zip"
+ZIP_NAME = "wiki18_100w_e5_index.zip"
+zip_path = DATA_DIR / ZIP_NAME
+
+# Download with progress bar
+print("πŸ“₯ Downloading index...")
+response = requests.get(URL, stream=True)
+total_size = int(response.headers.get("content-length", 0))
+
+with (
+    open(zip_path, "wb") as f,
+    tqdm(
+        desc=ZIP_NAME,
+        total=total_size,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar,
+):
+    for data in response.iter_content(chunk_size=1024):
+        size = f.write(data)
+        bar.update(size)
+
+# Extract
+print("πŸ“¦ Extracting index...")
+with zipfile.ZipFile(zip_path, "r") as zip_ref:
+    zip_ref.extractall(DATA_DIR)
+
+# Clean up zip
+os.remove(zip_path)
+print("βœ… Download and extraction completed successfully!")
+print(f"Index file is at: {DATA_DIR}/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index")
diff --git a/scripts/serving/download_generator_model.py b/scripts/serving/download_generator_model.py
new file mode 100644
index 0000000..27689ac
--- /dev/null
+++ b/scripts/serving/download_generator_model.py
@@ -0,0 +1,54 @@
+import argparse
+import os
+
+from dotenv import load_dotenv
+from huggingface_hub import snapshot_download
+
+from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse command line arguments.
+
+    Returns:
+        argparse.Namespace: Parsed arguments
+    """
+    parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        default=GENERATOR_MODEL_REPO_ID,
+        help="HuggingFace repository ID",
+    )
+    parser.add_argument(
+        "--local-dir",
+        type=str,
+        default=GENERATOR_MODEL_DIR,
+        help="Local directory to save model",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """Main function to download model."""
+    args = parse_args()
+    load_dotenv(override=True)
+
+    # Configuration
+    HF_TOKEN = os.getenv("HF_TOKEN")
+
+    print("Downloading model to", args.local_dir)
+
+    # Download the model
+    snapshot_download(
+        token=HF_TOKEN,
+        repo_id=args.repo_id,
+        local_dir=args.local_dir,
+        repo_type="model",
+    )
+    print(f"βœ… Done: {args.repo_id} -> {args.local_dir}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/serving/download_retriever_model.py b/scripts/serving/download_retriever_model.py
new file mode 100644
index 0000000..28bd4db
--- /dev/null
+++ b/scripts/serving/download_retriever_model.py
@@ -0,0 +1,54 @@
+import argparse
+import os
+
+from dotenv import load_dotenv
+from huggingface_hub import snapshot_download
+
+from config import RETRIEVER_MODEL_DIR, RETRIEVER_MODEL_REPO_ID
+
+
+def parse_args() -> argparse.Namespace:
+    """Parse command line arguments.
+
+    Returns:
+        argparse.Namespace: Parsed arguments
+    """
+    parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        default=RETRIEVER_MODEL_REPO_ID,
+        help="HuggingFace repository ID",
+    )
+    parser.add_argument(
+        "--local-dir",
+        type=str,
+        default=RETRIEVER_MODEL_DIR,
+        help="Local directory to save model",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """Main function to download model."""
+    args = parse_args()
+    load_dotenv(override=True)
+
+    # Configuration
+    HF_TOKEN = os.getenv("HF_TOKEN")
+
+    print("Downloading model to", args.local_dir)
+
+    # Download the model
+    snapshot_download(
+        token=HF_TOKEN,
+        repo_id=args.repo_id,
+        local_dir=args.local_dir,
+        repo_type="model",
+    )
+    print(f"βœ… Done: {args.repo_id} -> {args.local_dir}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/serving/retriever_config.yaml b/scripts/serving/retriever_config.yaml
new file mode 100644
index 0000000..fe5dd1a
--- /dev/null
+++ b/scripts/serving/retriever_config.yaml
@@ -0,0 +1,9 @@
+# ------------------------------------------------Environment Settings------------------------------------------------#
+gpu_id: "0"
+
+# -------------------------------------------------Retrieval Settings------------------------------------------------#
+# If the name is set, the model path will be found in the global paths
+retrieval_method: "e5" # name or path of the retrieval model
+index_path: "/mnt/nas/thinhlpg/data/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index" # path to the indexed file
+faiss_gpu: False # whether to use the GPU to hold the index
+corpus_path: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/wiki18_100w.jsonl" # path to the corpus in '.jsonl' format that stores the documents
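Usage sketch (illustrative, not part of the diff): serve_retriever.py below consumes this file through flashrag.utils.get_retriever. A minimal offline check that the index and corpus paths resolve, assuming flashrag is installed and the index/corpus have been downloaded:

    from flashrag.config import Config
    from flashrag.utils import get_retriever

    config = Config("retriever_config.yaml")
    retriever = get_retriever(config)  # loads the e5 model, FAISS index, and corpus
    # Same positional call shape serve_retriever.py uses: (query, num, return_score)
    docs = retriever.search("capital of France", 3, False)
    print([d["id"] for d in docs])
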
+ """ + command = [ + sys.executable, # Use the current Python interpreter + "-m", + "sglang.launch_server", + "--model-path", + model_id, + "--context-length", + str(context_length), + "--enable-metrics", + "--dtype", + dtype, + "--host", + host, + "--port", + str(port), + "--trust-remote-code", + # Recommended by SGLang for stability sometimes + "--disable-overlap", + # Can sometimes cause issues + "--disable-radix-cache", + ] + + # Log the command clearly + command_str = " ".join(command) + logger.info(f"πŸš€ Launching SGLang server with command: {command_str}") + + process = None # Initialize process to None + try: + # Use Popen to start the server process + # It runs in the foreground relative to this script, + # but allows us to catch KeyboardInterrupt cleanly. + process = subprocess.Popen(command) + # Wait for the process to complete (e.g., user interruption) + process.wait() + # Check return code after waiting + if process.returncode != 0: + logger.error(f"πŸ’₯ SGLang server process exited with error code: {process.returncode}") + sys.exit(process.returncode) + else: + logger.info("βœ… SGLang server process finished gracefully.") + + except FileNotFoundError: + logger.error(f"πŸ’₯ Error: Python executable or sglang module not found.") + logger.error(f"Ensure '{sys.executable}' is correct and sglang is installed.") + sys.exit(1) + except KeyboardInterrupt: + logger.info("πŸ›‘ SGLang server launch interrupted by user. Stopping server...") + # Attempt to terminate the process gracefully + if process and process.poll() is None: # Check if process exists and is running + process.terminate() + try: + process.wait(timeout=5) # Wait a bit for termination + logger.info("βœ… Server terminated gracefully.") + except subprocess.TimeoutExpired: + logger.warning("⚠️ Server did not terminate gracefully, forcing kill.") + process.kill() + sys.exit(0) # Exit cleanly after interrupt + except Exception as e: + # Catch any other unexpected exceptions during launch or waiting + logger.error(f"🚨 An unexpected error occurred: {e}") + # Ensure process is cleaned up if it exists + if process and process.poll() is None: + process.kill() + sys.exit(1) + + +if __name__ == "__main__": + # Get context length from config, default to 8192 + context_len = MODEL_CONFIG.get("max_seq_length", 8192) + + logger.info("----------------------------------------------------") + logger.info("✨ Starting SGLang Generator Server ✨") + logger.info(f" Model ID: {GENERATOR_MODEL_REPO_ID}") + logger.info(f" Port: {GENERATOR_SERVER_PORT}") + logger.info(f" Context Length: {context_len}") + logger.info("----------------------------------------------------") + + launch_sglang_server( + model_id=GENERATOR_MODEL_REPO_ID, + port=GENERATOR_SERVER_PORT, + context_length=context_len, + ) diff --git a/scripts/serving/serve_retriever.py b/scripts/serving/serve_retriever.py new file mode 100644 index 0000000..229351c --- /dev/null +++ b/scripts/serving/serve_retriever.py @@ -0,0 +1,113 @@ +import argparse +import asyncio +from collections import deque +from typing import List, Tuple, Union + +from fastapi import FastAPI, HTTPException +from flashrag.config import Config +from flashrag.utils import get_retriever +from pydantic import BaseModel + +from config import RETRIEVER_SERVER_PORT + +app = FastAPI() + +retriever_list = [] +available_retrievers = deque() +retriever_semaphore = None + + +def init_retriever(args): + global retriever_semaphore + config = Config(args.config) + for i in range(args.num_retriever): + print(f"Initializing 
diff --git a/scripts/serving/serve_retriever.py b/scripts/serving/serve_retriever.py
new file mode 100644
index 0000000..229351c
--- /dev/null
+++ b/scripts/serving/serve_retriever.py
@@ -0,0 +1,113 @@
+import argparse
+import asyncio
+from collections import deque
+from typing import List, Tuple, Union
+
+from fastapi import FastAPI, HTTPException
+from flashrag.config import Config
+from flashrag.utils import get_retriever
+from pydantic import BaseModel
+
+from config import RETRIEVER_SERVER_PORT
+
+app = FastAPI()
+
+retriever_list = []
+available_retrievers = deque()
+retriever_semaphore = None
+
+
+def init_retriever(args):
+    global retriever_semaphore
+    config = Config(args.config)
+    for i in range(args.num_retriever):
+        print(f"Initializing retriever {i + 1}/{args.num_retriever}")
+        retriever = get_retriever(config)
+        retriever_list.append(retriever)
+        available_retrievers.append(i)
+    # Create a semaphore to limit the number of retrievers that can be used concurrently
+    retriever_semaphore = asyncio.Semaphore(args.num_retriever)
+
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy", "retrievers": {"total": len(retriever_list), "available": len(available_retrievers)}}
+
+
+class QueryRequest(BaseModel):
+    query: str
+    top_n: int = 10
+    return_score: bool = False
+
+
+class BatchQueryRequest(BaseModel):
+    query: List[str]
+    top_n: int = 10
+    return_score: bool = False
+
+
+class Document(BaseModel):
+    id: str
+    contents: str
+
+
+@app.post("/search", response_model=Union[Tuple[List[Document], List[float]], List[Document]])
+async def search(request: QueryRequest):
+    query = request.query
+    top_n = request.top_n
+    return_score = request.return_score
+
+    if not query or not query.strip():
+        print(f"Query content cannot be empty: {query}")
+        raise HTTPException(status_code=400, detail="Query content cannot be empty")
+
+    async with retriever_semaphore:
+        retriever_idx = available_retrievers.popleft()
+        try:
+            if return_score:
+                results, scores = retriever_list[retriever_idx].search(query, top_n, return_score)
+                return [Document(id=result["id"], contents=result["contents"]) for result in results], scores
+            else:
+                results = retriever_list[retriever_idx].search(query, top_n, return_score)
+                return [Document(id=result["id"], contents=result["contents"]) for result in results]
+        finally:
+            available_retrievers.append(retriever_idx)
+
+
+@app.post("/batch_search", response_model=Union[List[List[Document]], Tuple[List[List[Document]], List[List[float]]]])
+async def batch_search(request: BatchQueryRequest):
+    query = request.query
+    top_n = request.top_n
+    return_score = request.return_score
+
+    async with retriever_semaphore:
+        retriever_idx = available_retrievers.popleft()
+        try:
+            if return_score:
+                results, scores = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
+                return [
+                    [Document(id=result["id"], contents=result["contents"]) for result in results[i]]
+                    for i in range(len(results))
+                ], scores
+            else:
+                results = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
+                return [
+                    [Document(id=result["id"], contents=result["contents"]) for result in results[i]]
+                    for i in range(len(results))
+                ]
+        finally:
+            available_retrievers.append(retriever_idx)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="./retriever_config.yaml")
+    parser.add_argument("--num_retriever", type=int, default=1)
+    parser.add_argument("--port", type=int, default=RETRIEVER_SERVER_PORT)
+    args = parser.parse_args()
+
+    init_retriever(args)
+
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=args.port)
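Usage sketch (illustrative, not part of the diff): the batch endpoint mirrors BatchQueryRequest, and RemoteRetriever._batch_search in run_eval.py sends the same shape, with the retriever assumed on localhost:8001:

    import requests

    resp = requests.post(
        "http://localhost:8001/batch_search",
        json={"query": ["first question", "second question"], "top_n": 5, "return_score": False},
        timeout=60,
    )
    resp.raise_for_status()
    for results in resp.json():  # one result list per query
        print([doc["id"] for doc in results])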