From 89e07bc02dc6c7ceff74e43646131273d210e63c Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Tue, 15 Apr 2025 05:52:35 +0000
Subject: [PATCH] chore: remove unused code and dependencies

---
 .env.example                                  |   3 +-
 .gitignore                                    |   1 +
 .gitmodules                                   |   3 -
 Makefile                                      | 185 +----
 scripts/evaluation/eval_config.yaml           |  46 --
 scripts/evaluation/run_eval.py                | 656 ------------------
 scripts/serving/download_flashrag_datasets.py |  68 --
 scripts/serving/download_flashrag_index.py    |  43 --
 scripts/serving/download_generator_model.py   |  54 --
 scripts/serving/download_retriever_model.py   |  54 --
 scripts/serving/retriever_config.yaml         |   9 -
 scripts/serving/serve_generator.py            | 127 ----
 scripts/serving/serve_retriever.py            | 113 ---
 scripts/train_data/build_musique_index.py     | 135 ----
 scripts/train_data/download_data_musique.sh   |  30 -
 .../train_data/extract_musique_paragraphs.py  | 101 ---
 .../train_data/prepare_musique_dev_jsonl.py   | 155 -----
 scripts/train_data/prepare_musique_jsonl.py   | 172 -----
 third_party/FlashRAG                          |   1 -
 19 files changed, 31 insertions(+), 1925 deletions(-)
 delete mode 100644 .gitmodules
 delete mode 100644 scripts/evaluation/eval_config.yaml
 delete mode 100644 scripts/evaluation/run_eval.py
 delete mode 100644 scripts/serving/download_flashrag_datasets.py
 delete mode 100644 scripts/serving/download_flashrag_index.py
 delete mode 100644 scripts/serving/download_generator_model.py
 delete mode 100644 scripts/serving/download_retriever_model.py
 delete mode 100644 scripts/serving/retriever_config.yaml
 delete mode 100644 scripts/serving/serve_generator.py
 delete mode 100644 scripts/serving/serve_retriever.py
 delete mode 100644 scripts/train_data/build_musique_index.py
 delete mode 100644 scripts/train_data/download_data_musique.sh
 delete mode 100644 scripts/train_data/extract_musique_paragraphs.py
 delete mode 100644 scripts/train_data/prepare_musique_dev_jsonl.py
 delete mode 100644 scripts/train_data/prepare_musique_jsonl.py
 delete mode 160000 third_party/FlashRAG

diff --git a/.env.example b/.env.example
index a9fa6d4..ec24d52 100644
--- a/.env.example
+++ b/.env.example
@@ -1,2 +1,3 @@
 HF_TOKEN=
-OPENROUTER_API_KEY=
\ No newline at end of file
+TAVILY_API_KEY=
+SERPER_API_KEY=
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index be40a92..ceed6b6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,6 +19,7 @@ logs/
 data/
 .gradio/
 output*
+llama.cpp*
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/.gitmodules b/.gitmodules
deleted file mode 100644
index b45d826..0000000
--- a/.gitmodules
+++ /dev/null
@@ -1,3 +0,0 @@
-[submodule "third_party/FlashRAG"]
- path = third_party/FlashRAG
- url = https://github.com/RUC-NLPIR/FlashRAG.git
diff --git a/Makefile b/Makefile
index 9a98575..625143c 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-.PHONY: style quality install tensorboard clean fix update-worklog test data download-musique prepare-musique-jsonl extract-musique-paragraphs build-musique-index prepare-musique-index prepare-all-musique check-data prepare-dev-data ensure-unzip download-all-models serve-retriever serve-generator run-evaluation download-flashrag-data download-flashrag-index download-retriever-model download-generator-model serve-all run-full-evaluation evaluation-download-models prepare-serving serve-background stop-serving
+.PHONY: style quality install clean fix test check-data simple-qa generate-data eval-base eval-lora download-checkpoint upload-checkpoint save-merged-16bit
 
 # make sure to test the local checkout in scripts and not the pre-installed one
export PYTHONPATH = src @@ -11,7 +11,7 @@ test: # Development dependencies install: - pip install -e . && pip install -e third_party/FlashRAG + pip install -e . # Code quality and style style: @@ -28,159 +28,38 @@ fix: ruff check --fix --line-length 119 --target-version py311 $(check_dirs) isort $(check_dirs) -# TensorBoard -tensorboard: - tensorboard --logdir=trainer_output_*_runs --port=6006 - -# List available run directories -list-runs: - @echo "Available run directories:" - @ls -d trainer_output_*_runs 2>/dev/null || echo "No run directories found" - -# Ensure unzip is available -ensure-unzip: - @which unzip > /dev/null || (echo "Installing unzip..." && sudo apt-get update && sudo apt-get install -y unzip) - @echo "✓ unzip is available" - -# Data Preparation - One command to rule them all -data: download-musique prepare-musique-jsonl extract-musique-paragraphs build-musique-index prepare-dev-data check-data - @echo "✨ All data preparation complete! ✨" - -# Index Preparation -prepare-musique-index: build-musique-index - @echo "Musique index preparation complete." +# Check Data +check-data: + @echo "Checking generated data files..." + python scripts/check_data.py -download-musique: ensure-unzip - @echo "Downloading Musique dataset..." - bash scripts/train_data/download_data_musique.sh - @echo "Musique dataset ready in ./data/raw/" +# Simple QA +simple-qa: + python scripts/simple_qa.py -prepare-musique-jsonl: download-musique - @echo "Preparing Musique data (JSONL)..." - python scripts/train_data/prepare_musique_jsonl.py - @echo "Processed Musique JSONL ready in ./data/processed/questions.jsonl" +# Generate data +generate-data: + python scripts/generate_data.py -extract-musique-paragraphs: download-musique - @echo "Extracting unique paragraphs from raw Musique data..." - python scripts/train_data/extract_musique_paragraphs.py - @echo "Musique paragraphs extracted to ./data/processed/paragraphs.csv" +# Evaluate base model +eval-base: + python scripts/eval_base.py -build-musique-index: extract-musique-paragraphs - @echo "Building Musique FAISS index from paragraphs..." - python scripts/train_data/build_musique_index.py - @echo "Musique FAISS index files saved to ./data/processed/" +# Evaluate LoRA model +eval-lora: + python scripts/eval_lora.py -# Combined Preparation -prepare-all-musique: data prepare-musique-index - @echo "All Musique data and index preparation complete." +# Download checkpoint +download-checkpoint: + python scripts/download_checkpoint.py -# Check Data -check-data: - @echo "Checking generated data files..." - python scripts/check_data.py +# Upload checkpoint +upload-checkpoint: + python scripts/upload_checkpoint.py -# Prepare Dev Data -prepare-dev-data: download-musique - @echo "Preparing Musique DEV data (JSONL)..." - python scripts/train_data/prepare_musique_dev_jsonl.py - @echo "Processed Musique DEV JSONL ready in ./data/processed/questions_dev.jsonl" - -# ======= SERVING COMMANDS ======= - -# Prepare everything needed for serving (download models and data) -prepare-serving: download-all-models - @echo "✨ All models and data for serving prepared! ✨" - @echo "You can now run services with:" - @echo " make serve-retriever" - @echo " make serve-generator" - @echo " or both with separate terminals" - -# Download all required models and data for serving -download-all-models: download-flashrag-data download-flashrag-index download-retriever-model download-generator-model - @echo "✨ All models and data downloaded! 
✨" - -# Download FlashRAG datasets -download-flashrag-data: - @echo "Downloading FlashRAG datasets..." - python scripts/serving/download_flashrag_datasets.py - @echo "FlashRAG datasets downloaded!" - -# Download FlashRAG index -download-flashrag-index: - @echo "Downloading FlashRAG index..." - python scripts/serving/download_flashrag_index.py - @echo "FlashRAG index downloaded!" - -# Download retriever model -download-retriever-model: - @echo "Downloading retriever model..." - python scripts/serving/download_retriever_model.py - @echo "Retriever model downloaded!" - -# Download generator model -download-generator-model: - @echo "Downloading generator model..." - python scripts/serving/download_generator_model.py - @echo "Generator model downloaded!" - -# Serve retriever -serve-retriever: download-retriever-model download-flashrag-index download-flashrag-data - @echo "Starting retriever service..." - python scripts/serving/serve_retriever.py --config scripts/serving/retriever_config.yaml - -# Serve generator -serve-generator: download-generator-model - @echo "Starting generator service..." - python scripts/serving/serve_generator.py - -# Start both services (retriever and generator) in the background -serve-background: prepare-serving - @echo "Starting both retriever and generator services in background..." - @mkdir -p logs - @echo "Starting retriever in background..." - @nohup python scripts/serving/serve_retriever.py --config scripts/serving/retriever_config.yaml > logs/retriever.log 2>&1 & - @echo "Retriever started! PID: $$!" - @echo "Starting generator in background..." - @nohup python scripts/serving/serve_generator.py > logs/generator.log 2>&1 & - @echo "Generator started! PID: $$!" - @echo "✨ Both services running in background! ✨" - @echo "Check logs in logs/retriever.log and logs/generator.log" - @echo "To stop services: make stop-serving" - -# Stop all serving processes -stop-serving: - @echo "Stopping all serving processes..." - @pkill -f 'python scripts/serving/serve_' || echo "No serving processes found" - @echo "✅ All services stopped!" - -# Serve all components -serve-all: download-all-models - @echo "Starting all services..." - @echo "Please run these commands in separate terminals:" - @echo " make serve-retriever" - @echo " make serve-generator" - @echo "" - @echo "Or run both in background with one command:" - @echo " make serve-background" - @echo "" - @echo "To stop background services:" - @echo " make stop-serving" - -# ======= EVALUATION COMMANDS ======= - -# Download models needed for evaluation -evaluation-download-models: download-all-models - @echo "✨ All models for evaluation downloaded! ✨" - -# Run evaluation script -run-evaluation: - @echo "Running evaluation..." - python scripts/evaluation/run_eval.py --config scripts/evaluation/eval_config.yaml - @echo "Evaluation complete! Results in scripts/evaluation/output_logs/" - -# Run complete evaluation pipeline -run-full-evaluation: evaluation-download-models run-evaluation - @echo "✨ Full evaluation pipeline complete! ✨" +# Save merged 16bit model +save-merged-16bit: + python scripts/save_merged_16bit.py # Clean up clean: @@ -196,12 +75,4 @@ clean: find . -type d -name ".coverage" -exec rm -r {} + find . -type d -name "htmlcov" -exec rm -r {} + find . -type d -name "build" -exec rm -r {} + - find . 
-type d -name "dist" -exec rm -r {} + - rm -rf ./data/raw ./data/processed # Clean raw and processed data - # Clean up the old faiss_index directory if it exists - rm -rf ./data/processed/faiss_index - -# Update worklog in GitHub issue -update-worklog: - gh api -X PATCH /repos/menloresearch/DeepSearch/issues/comments/2743047160 \ - -f body="$$(cat docs/00_worklog.md)" | cat && kill -9 $$PPID \ No newline at end of file + find . -type d -name "dist" -exec rm -r {} + \ No newline at end of file diff --git a/scripts/evaluation/eval_config.yaml b/scripts/evaluation/eval_config.yaml deleted file mode 100644 index ad503d6..0000000 --- a/scripts/evaluation/eval_config.yaml +++ /dev/null @@ -1,46 +0,0 @@ -# ------------------------------------------------Environment Settings------------------------------------------------# -# Directory paths for data and outputs -data_dir: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/" -save_dir: "/mnt/nas/thinhlpg/code/DeepSearch/logs" - -# Seed for reproducibility -seed: 2024 - -# Whether save intermediate data -save_intermediate_data: True -save_note: 'experiment' - -# -------------------------------------------------Retrieval Settings------------------------------------------------# -# If set the remote url, the retriever will be a remote retriever and ignore following settings -use_remote_retriever: True -remote_retriever_url: "localhost:8001" - -instruction: ~ # instruction for retrieval model -retrieval_topk: 5 # number of retrieved documents -retrieval_batch_size: 256 # batch size for retrieval -retrieval_use_fp16: True # whether to use fp16 for retrieval model -retrieval_query_max_length: 128 # max length of the query -save_retrieval_cache: False # whether to save the retrieval cache -use_retrieval_cache: False # whether to use the retrieval cache -retrieval_cache_path: ~ # path to the retrieval cache -retrieval_pooling_method: ~ # set automatically if not provided - -# -------------------------------------------------Generator Settings------------------------------------------------# -framework: sgl_remote # inference frame work of LLM, supporting: 'hf','vllm','fschat' -sgl_remote_url: "localhost:8002" -generator_model: "janhq/250404-llama-3.2-3b-instruct-grpo-03-s250" # name or path of the generator model, for laoding tokenizer -generator_max_input_len: 2048 # max length of the input -generation_params: - do_sample: False - max_tokens: 8192 - -# -------------------------------------------------Evaluation Settings------------------------------------------------# -# Metrics to evaluate the result -metrics: [ 'em','f1','acc','precision','recall'] -# Specify setting for metric, will be called within certain metrics -metric_setting: - retrieval_recall_topk: 5 -save_metric_score: True # whether to save the metric score into txt file - - - diff --git a/scripts/evaluation/run_eval.py b/scripts/evaluation/run_eval.py deleted file mode 100644 index 7b6810d..0000000 --- a/scripts/evaluation/run_eval.py +++ /dev/null @@ -1,656 +0,0 @@ -import argparse -import json -import os -import re -import time -from copy import deepcopy -from datetime import datetime -from functools import wraps - -import numpy as np -import requests -from flashrag.config import Config -from flashrag.generator.generator import BaseGenerator -from flashrag.pipeline import BasicPipeline -from flashrag.retriever.retriever import BaseTextRetriever -from flashrag.utils import get_dataset -from transformers import AutoTokenizer - -from config import logger -from src.agent import Agent, 
AgenticOutputs -from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter - - -def retry(max_retries=10, sleep=1): - """Decorator to retry a function with exponential backoff.""" - - def decorator(func): - @wraps(func) - def wrapper(*args, **kwargs): - func_name = func.__name__ - for attempt in range(max_retries): - try: - result = func(*args, **kwargs) - return result - except Exception as e: - logger.warning(f"Attempt {attempt + 1} of {func_name} failed: {e}") - if attempt == max_retries - 1: - logger.error(f"Function {func_name} failed after {max_retries} retries.", exc_info=True) - raise e - backoff_time = sleep * (2**attempt) - logger.info(f"Retrying {func_name} in {backoff_time:.2f} seconds...") - time.sleep(backoff_time) - logger.error(f"Function {func_name} retry logic finished unexpectedly.") - return None - - return wrapper - - return decorator - - -class RemoteRetriever(BaseTextRetriever): - """A wrapper for remote retriever service with retry logic and logging.""" - - def __init__(self, config: Config): - """Initializes the RemoteRetriever.""" - super().__init__(config) - self.remote_url = f"http://{getattr(config, 'remote_retriever_url', 'localhost:8001')}" - self.topk = getattr(config, "retriever_topk", 5) - logger.info(f"🔗 Remote retriever URL: {self.remote_url}") - - @retry(max_retries=3, sleep=2) - def _search(self, query: str, num: int | None = None, return_score: bool = False) -> list[dict]: - """Search for documents using the remote retriever service.""" - num = num if num is not None else self.topk - url = f"{self.remote_url}/search" - - try: - response = requests.post( - url, - json={"query": query, "top_n": num, "return_score": return_score}, - timeout=30, - ) - response.raise_for_status() - - results = response.json() - return results - except requests.exceptions.Timeout: - logger.error(f"Search request timed out after 30 seconds for query: {query[:50]}...") - raise - except requests.exceptions.ConnectionError: - logger.error(f"Could not connect to search service at {url}") - raise - except requests.exceptions.RequestException as e: - logger.error(f"Search request failed: {e}", exc_info=True) - raise - except Exception as e: - logger.error(f"Unexpected search error: {str(e)}", exc_info=True) - raise - - @retry(max_retries=3, sleep=2) - def _batch_search( - self, queries: list[str], num: int | None = None, return_score: bool = False - ) -> list[list[dict]]: - """Batch search for documents using the remote retriever service.""" - num = num if num is not None else self.topk - url = f"{self.remote_url}/batch_search" - - try: - response = requests.post( - url, - json={"query": queries, "top_n": num, "return_score": return_score}, - timeout=60, - ) - response.raise_for_status() - results = response.json() - return results - except requests.exceptions.Timeout: - logger.error(f"Batch search request timed out after 60 seconds for {len(queries)} queries.") - raise - except requests.exceptions.ConnectionError: - logger.error(f"Could not connect to batch search service at {url}") - raise - except requests.exceptions.RequestException as e: - logger.error(f"Batch search request failed: {e}", exc_info=True) - raise - except Exception as e: - logger.error(f"Unexpected batch search error: {str(e)}", exc_info=True) - raise - - -class ReSearchPipeline(BasicPipeline): - """Pipeline for ReSearch method using Agent for generation and tool use.""" - - def __init__( - self, config: Config, retriever: BaseTextRetriever | None = None, generator: BaseGenerator | None = 
None - ): - """Initializes the ReSearchPipeline.""" - super().__init__(config) - logger.info("🔧 Initializing ReSearchPipeline...") - - self.retriever = retriever or RemoteRetriever(config) - - self.generator = generator or SGLRemoteGenerator(config) - - try: - self.tokenizer = AutoTokenizer.from_pretrained(config.generator_model_path, trust_remote_code=True) - if not self.tokenizer.pad_token: - logger.warning("Tokenizer does not have a pad token; setting to eos_token.") - self.tokenizer.pad_token = self.tokenizer.eos_token - self.tokenizer.padding_side = "left" - logger.info("✅ Tokenizer initialized.") - except Exception as e: - logger.error(f"Failed to initialize tokenizer: {e}", exc_info=True) - raise - - tokenizer_name = self.tokenizer.name_or_path.lower() - - if "deepseek-ai/deepseek-r1-distill" in tokenizer_name: - adapter = R1DistilTokenizerAdapter() - elif "llama" in tokenizer_name: - adapter = LlamaTokenizerAdapter() - else: - logger.warning(f"Unknown tokenizer type '{tokenizer_name}', defaulting to R1DistilTokenizerAdapter.") - adapter = R1DistilTokenizerAdapter() - logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}") - - def retriever_search(query: str, return_type=str, results: int = 5): - try: - search_results = self.retriever._search(query, num=results) - return self.format_search_results(search_results) - except Exception as e: - logger.error(f"Error during agent's retriever search for query '{query[:50]}...': {e}", exc_info=True) - return "Search failed due to an internal error." - - self.agent = Agent(adapter, search_fn=retriever_search) - logger.info("✅ Agent initialized.") - logger.info("✅ ReSearchPipeline initialized successfully.") - - def format_search_results(self, search_results: list[dict]) -> str: - """Formats search results into a string for the agent prompt.""" - if not search_results: - return "No results found." - max_content_len = 500 - formatted = "\n-------\n".join( - [ - f"Result {i + 1}: {r.get('contents', 'N/A')[:max_content_len]}{'...' if len(r.get('contents', '')) > max_content_len else ''}" - for i, r in enumerate(search_results) - ] - ) - formatted_str = f"{formatted}" - - return formatted_str - - def extract_search_query(self, text: str) -> str | None: - """Extract search query from text between tags.""" - pattern = re.compile(r"(.*?)", re.DOTALL) - matches = pattern.findall(text) - if matches: - query = matches[-1].strip() - return query - return None - - def extract_answer(self, text: str) -> str | None: - """Extract answer from text between tags.""" - pattern = re.compile(r"(.*?)", re.DOTALL) - matches = pattern.findall(text) - if matches: - answer = matches[-1].strip() - - return answer - - return None - - def run(self, dataset, do_eval: bool = True, pred_process_fun=None): - """Runs the ReSearch pipeline on the dataset using the Agent.""" - logger.info(f"🏃 Starting ReSearch pipeline run with {len(dataset)} items...") - - try: - questions = [item.question if hasattr(item, "question") else item["question"] for item in dataset] - - except (KeyError, AttributeError, TypeError) as e: - logger.error(f"Failed to extract questions from dataset items. 
Error: {e}", exc_info=True) - logger.error("Ensure dataset items have a 'question' key or attribute.") - return dataset - - agent_max_generations = getattr(self.config, "agent_max_generations", 32) - generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576) - - try: - logger.info(f"🤖 Running agent inference for {len(questions)} questions...") - agent_outputs: AgenticOutputs = self.agent.run_agent( - generate_fn=self.generator.generate, - tokenizer=self.tokenizer, - questions=questions, - max_generations=agent_max_generations, - max_new_tokens=generator_max_output_len, - ) - final_responses = agent_outputs.final_response_str - logger.info(f"✅ Agent inference completed. Received {len(final_responses)} final responses.") - - except Exception as e: - logger.error(f"Agent run failed during inference: {e}", exc_info=True) - logger.warning("Agent run failed, attempting evaluation with potentially incomplete results.") - for item in dataset: - if hasattr(item, "update_output"): - item.update_output("pred", "AGENT_ERROR") - elif isinstance(item, dict): - item["pred"] = "AGENT_ERROR" - - logger.info("📝 Extracting answers and updating dataset items...") - num_updated = 0 - num_missing_answers = 0 - if len(final_responses) == len(dataset): - for i, item in enumerate(dataset): - response = final_responses[i] - answer = self.extract_answer(response) - pred_to_save = answer if answer is not None else "" - - if answer is None: - num_missing_answers += 1 - - if hasattr(item, "update_output"): - item.update_output("pred", pred_to_save) - item.update_output("final_response", response) - num_updated += 1 - elif isinstance(item, dict): - item["pred"] = pred_to_save - item["final_response"] = response - num_updated += 1 - else: - logger.warning(f"Item {i} has unknown type {type(item)}, cannot update with prediction.") - - logger.info(f"Updated {num_updated}/{len(dataset)} dataset items with predictions.") - if num_missing_answers > 0: - logger.warning(f"{num_missing_answers} items had no tag.") - else: - logger.error( - f"Mismatch between dataset size ({len(dataset)}) and number of agent responses ({len(final_responses)}). Cannot reliably update dataset." 
- ) - for item in dataset: - if hasattr(item, "update_output"): - item.update_output("pred", "RESPONSE_COUNT_MISMATCH") - elif isinstance(item, dict): - item["pred"] = "RESPONSE_COUNT_MISMATCH" - - if do_eval: - logger.info("📊 Evaluating results using BasicPipeline.evaluate...") - try: - dataset = self.evaluate(dataset, do_eval=True, pred_process_fun=pred_process_fun) - logger.info("✅ Evaluation completed via base class method.") - except Exception as e: - logger.error(f"Error during BasicPipeline.evaluate: {e}", exc_info=True) - logger.warning("Evaluation may be incomplete.") - else: - logger.info("Skipping evaluation step as do_eval=False.") - - logger.info("✅ ReSearch pipeline run finished.") - return dataset - - -class SGLRemoteGenerator(BaseGenerator): - """Class for decoder-only generator, based on SGLang remote service.""" - - def __init__(self, config: Config): - """Initializes the SGLRemoteGenerator.""" - super().__init__(config) - logger.info("🔧 Initializing SGLRemoteGenerator...") - sgl_url = getattr(config, "sgl_remote_url", "localhost:8002") - self.sgl_remote_url = f"http://{sgl_url}/generate" - self.health_check_url = f"http://{sgl_url}/health" - logger.info(f"🔗 Remote Generator URL: {self.sgl_remote_url}") - self.model_path = getattr(config, "generator_model_path", None) - if not self.model_path: - logger.error("generator_model_path not found in config!") - raise ValueError("generator_model_path is required for SGLRemoteGenerator") - - try: - self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) - logger.info("✅ Tokenizer loaded for generator.") - except Exception as e: - logger.error(f"Failed to load tokenizer for generator from {self.model_path}: {e}", exc_info=True) - raise - - self.generation_params = getattr(config, "generation_params", {}) - self.config = config - - self._check_health() - - def _check_health(self): - """Checks the health of the remote generator service.""" - try: - test_response = requests.get(self.health_check_url, timeout=5) - test_response.raise_for_status() - logger.info("✅ Remote generator service is available") - except requests.exceptions.RequestException as e: - logger.error(f"Could not connect or verify remote generator service at {self.health_check_url}: {str(e)}") - logger.warning("Please ensure the SGLang service is running and accessible.") - - @retry(max_retries=5, sleep=2) - def generate( - self, - input_list: list[str] | str, - return_raw_output: bool = False, - return_scores: bool = False, - **params, - ) -> list[str] | tuple[list[str], list[list[float]]] | list[dict]: - """Generates text using the remote SGLang service.""" - if isinstance(input_list, str): - input_list = [input_list] - if not isinstance(input_list, list) or not all(isinstance(item, str) for item in input_list): - raise ValueError("Input must be a string or a list of strings.") - - batch_size = len(input_list) - data_to_remote = {"text": input_list} - - effective_params = deepcopy(self.generation_params) - effective_params.update(params) - - curr_sampling_params = {} - if effective_params.get("do_sample", True) is False: - curr_sampling_params["temperature"] = 0.0 - else: - curr_sampling_params["temperature"] = effective_params.get( - "temperature", getattr(self.config, "temperature", 0.7) - ) - - default_max_tokens = getattr(self.config, "generator_max_output_len", 1024) - curr_sampling_params["max_new_tokens"] = effective_params.get("max_new_tokens", default_max_tokens) - - stop_sequences = effective_params.get("stop", []) - if 
isinstance(stop_sequences, str): - stop_sequences = [stop_sequences] - if stop_sequences: - curr_sampling_params["stop"] = stop_sequences - - keys_to_remove = ["do_sample", "temperature", "max_new_tokens", "stop"] - for key in keys_to_remove: - effective_params.pop(key, None) - - if "top_p" in effective_params: - curr_sampling_params["top_p"] = effective_params["top_p"] - if "top_k" in effective_params: - curr_sampling_params["top_k"] = effective_params["top_k"] - - data_to_remote["sampling_params"] = curr_sampling_params - - if return_scores: - data_to_remote["return_logprob"] = True - data_to_remote["top_logprobs_num"] = getattr(self.config, "top_logprobs_num", 2) - - try: - response = requests.post( - self.sgl_remote_url, json=data_to_remote, timeout=120, headers={"Content-Type": "application/json"} - ) - response.raise_for_status() - - response_list = response.json() - - if return_raw_output: - return response_list - - generated_text = [] - for item in response_list: - text = item.get("text", "") - finish_reason = item.get("meta_info", {}).get("finish_reason", {}) - matched_stop = finish_reason.get("matched") - if matched_stop and curr_sampling_params.get("stop") and matched_stop in curr_sampling_params["stop"]: - text += matched_stop - generated_text.append(text) - - if return_scores: - scores = [] - for resp_item in response_list: - logprobs_list = resp_item.get("meta_info", {}).get("output_token_logprobs", []) - token_scores = [ - np.exp(logprob[0]) if (logprob and len(logprob) > 0) else 0.0 for logprob in logprobs_list - ] - scores.append(token_scores) - return generated_text, scores - else: - return generated_text - - except requests.exceptions.Timeout: - logger.error("Generation request timed out after 120 seconds.") - raise - except requests.exceptions.ConnectionError: - logger.error(f"Could not connect to remote generator service at {self.sgl_remote_url}.") - raise - except requests.exceptions.RequestException as e: - logger.error(f"Network error during generation: {str(e)}", exc_info=True) - raise - except json.JSONDecodeError: - response_text = "Unknown (error occurred before response object assignment)" - if "response" in locals() and hasattr(response, "text"): - response_text = response.text[:500] - logger.error( - f"Failed to decode JSON response from {self.sgl_remote_url}. Response text: {response_text}...", - exc_info=True, - ) - raise - except Exception as e: - logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True) - raise - - -def load_dataset_items(config: Config, split: str) -> list[dict | object]: - """Loads dataset items using flashrag's get_dataset.""" - logger.info(f"📚 Loading dataset: {config.dataset_name}, Split: {split}") - try: - all_splits = get_dataset(config) - if split not in all_splits: - logger.error( - f"Split '{split}' not found in dataset '{config.dataset_name}'. Available splits: {list(all_splits.keys())}" - ) - return [] - dataset_items = all_splits[split] - logger.info(f"Successfully loaded {len(dataset_items)} items for split '{split}'.") - - return dataset_items - except FileNotFoundError: - logger.error( - f"Dataset files not found for '{config.dataset_name}' in '{config.data_dir}'. Check config and paths." 
- ) - return [] - except Exception as e: - logger.error(f"Error loading dataset using get_dataset: {e}", exc_info=True) - return [] - - -def save_results(args: argparse.Namespace, config: Config, result_dataset, run_duration: float): - """Saves summary and debug information.""" - logger.info("💾 Saving results...") - summary_file = os.path.join(args.save_dir, f"{args.save_note}_summary.txt") - debug_file = os.path.join(args.save_dir, f"{args.save_note}_debug.json") - - num_items = len(result_dataset) - - logger.info(f"Saving summary results to {summary_file}...") - try: - with open(summary_file, "w", encoding="utf-8") as f: - f.write("EVALUATION SUMMARY\n") - f.write("=================\n\n") - f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") - f.write(f"Run Duration: {run_duration:.2f} seconds\n") - f.write(f"Dataset: {config.dataset_name} ({args.split} split)\n") - f.write(f"Model: {config.generator_model_path}\n") - f.write(f"Retriever: {config.remote_retriever_url}\n") - f.write(f"Agent Max Generations: {getattr(config, 'agent_max_generations', 'N/A')}\n") - f.write(f"Generator Max Output Len: {getattr(config, 'generator_max_output_len', 'N/A')}\n\n") - f.write(f"Total items processed: {num_items}\n") - f.write("\nNote: Verification was skipped in this run.\n") - f.write("Note: Overall metrics (like EM, F1) are usually printed to console by evaluate method.\n") - - logger.info(f"✅ Summary saved to {summary_file}") - except Exception as e: - logger.error(f"Error saving summary file '{summary_file}': {e}", exc_info=True) - - logger.info(f"Saving debug information (predictions & responses) to {debug_file}...") - try: - debug_data = [] - for i, item in enumerate(result_dataset): - item_data: dict[str, object] = {} - - def get_item_value(data_item, key_or_attr: str) -> str | int | float | list | bool | None: - if isinstance(data_item, dict): - return data_item.get(key_or_attr) - elif hasattr(data_item, key_or_attr): - return getattr(data_item, key_or_attr) - return None - - item_data["item_index"] = i - item_data["question"] = get_item_value(item, "question") - item_data["prediction"] = get_item_value(item, "pred") - item_data["final_response"] = get_item_value(item, "final_response") - - gt_answer_val = None - try: - gt_answer_val = get_item_value(item, "answer") - if gt_answer_val is None: - answers_list = get_item_value(item, "answers") - if isinstance(answers_list, list) and answers_list: - raw_ans = answers_list[0] - if isinstance(raw_ans, (str, int, float, bool)): - gt_answer_val = raw_ans - else: - gt_answer_val = str(raw_ans) - elif not isinstance(gt_answer_val, (str, int, float, bool)): - gt_answer_val = str(gt_answer_val) - except Exception as e: - logger.warning(f"Could not safely get ground truth for item {i}: {e}") - gt_answer_val = "ERROR_GETTING_ANSWER" - item_data["ground_truth"] = gt_answer_val - - eval_score_val = None - try: - eval_score_val = get_item_value(item, "score") - if not isinstance(eval_score_val, (str, int, float, bool, type(None))): - eval_score_val = str(eval_score_val) - except Exception as e: - logger.warning(f"Could not safely get score for item {i}: {e}") - eval_score_val = "ERROR_GETTING_SCORE" - item_data["eval_score"] = eval_score_val - - debug_data.append(item_data) - - with open(debug_file, "w", encoding="utf-8") as f: - json.dump(debug_data, f, indent=2, ensure_ascii=False) - logger.info(f"✅ Debug information saved to {debug_file}") - except Exception as e: - logger.error(f"Error saving debug file '{debug_file}': {e}", 
exc_info=True) - - -def research(args: argparse.Namespace, config: Config): - """Main function to run the research evaluation pipeline.""" - logger.info("🚀 Starting research pipeline execution...") - start_time = time.time() - - test_data = load_dataset_items(config, args.split) - if not test_data: - logger.error("Failed to load test data. Exiting.") - return - - try: - logger.info("🏗️ Building ReSearchPipeline...") - pipeline = ReSearchPipeline(config) - logger.info("✅ Pipeline built successfully.") - except Exception as e: - logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True) - return - - agent_max_generations = getattr(config, "agent_max_generations", 32) - generator_max_output_len = getattr(config, "generator_max_output_len", 24576) - - try: - logger.info("🏃 Starting pipeline run...") - result_dataset = pipeline.run(test_data, do_eval=True) - logger.info("✅ Pipeline run completed.") - except Exception as e: - logger.error(f"Error during pipeline run: {e}", exc_info=True) - result_dataset = test_data - logger.warning("Pipeline run failed, attempting to save inputs/partial results.") - - run_duration = time.time() - start_time - logger.info(f"Total run duration: {run_duration:.2f} seconds.") - save_results(args, config, result_dataset, run_duration) - - logger.info("🏁 Research pipeline execution finished.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Running ReSearch Evaluation Pipeline") - parser.add_argument( - "--config_path", type=str, default="./eval_config.yaml", help="Path to the main FlashRAG config file." - ) - parser.add_argument( - "--dataset_name", - type=str, - default="bamboogle", - help="Name of the dataset (must match config or data_dir structure).", - ) - parser.add_argument( - "--split", type=str, default="test", help="Dataset split to evaluate (e.g., test, validation)." - ) - parser.add_argument("--save_dir", type=str, default="./output_logs", help="Directory to save logs and results.") - parser.add_argument("--save_note", type=str, default="research_run", help="A note to prepend to saved filenames.") - - parser.add_argument("--data_dir", type=str, help="Override data directory specified in config.") - parser.add_argument( - "--sgl_remote_url", type=str, help="Override SGLang remote generator URL (e.g., localhost:8002)." - ) - parser.add_argument( - "--remote_retriever_url", type=str, help="Override remote retriever URL (e.g., localhost:8001)." 
- ) - parser.add_argument("--generator_model_path", type=str, help="Override generator model path specified in config.") - parser.add_argument("--retriever_topk", type=int, help="Override retriever top K.") - parser.add_argument("--generator_max_output_len", type=int, help="Override generator max output length.") - parser.add_argument("--agent_max_generations", type=int, help="Override agent max interaction turns.") - - args = parser.parse_args() - logger.info(f"Starting evaluation script with arguments: {args}") - - try: - os.makedirs(args.save_dir, exist_ok=True) - logger.info(f"💾 Logs and results will be saved to: {args.save_dir}") - except OSError as e: - logger.error(f"Could not create save directory '{args.save_dir}': {e}", exc_info=True) - exit(1) - - config_overrides = { - k: v - for k, v in vars(args).items() - if v is not None - and k - not in [ - "config_path", - "dataset_name", - "split", - "save_dir", - "save_note", - ] - } - - logger.info(f"🔧 Loading configuration from: {args.config_path}") - try: - config = Config(args.config_path, config_dict=config_overrides) - config.dataset_name = args.dataset_name - if args.data_dir: - config.data_dir = args.data_dir - - logger.info(f"Effective data_dir: {getattr(config, 'data_dir', 'N/A')}") - logger.info(f"Effective generator_model_path: {getattr(config, 'generator_model_path', 'N/A')}") - logger.info(f"Effective sgl_remote_url: {getattr(config, 'sgl_remote_url', 'N/A')}") - logger.info(f"Effective remote_retriever_url: {getattr(config, 'remote_retriever_url', 'N/A')}") - - logger.info("✅ Config loaded and potentially overridden by CLI arguments.") - - config["dataset_path"] = os.path.join(config.data_dir, config.dataset_name) - - except FileNotFoundError: - logger.error(f"Config file not found at '{args.config_path}'. Please check the path.") - exit(1) - except Exception as e: - logger.error(f"Error loading or processing configuration: {e}", exc_info=True) - exit(1) - - research(args, config) diff --git a/scripts/serving/download_flashrag_datasets.py b/scripts/serving/download_flashrag_datasets.py deleted file mode 100644 index 53b58f3..0000000 --- a/scripts/serving/download_flashrag_datasets.py +++ /dev/null @@ -1,68 +0,0 @@ -import argparse -import os -import zipfile - -from dotenv import load_dotenv -from huggingface_hub import snapshot_download - -from config import DATA_DIR - - -def parse_args() -> argparse.Namespace: - """Parse command line arguments. - - Returns: - argparse.Namespace: Parsed arguments - """ - parser = argparse.ArgumentParser(description="Download FlashRAG datasets from HuggingFace Hub") - parser.add_argument( - "--repo-id", - type=str, - default="RUC-NLPIR/FlashRAG_datasets", - help="HuggingFace repository IDs", - ) - parser.add_argument( - "--local-dir", - type=str, - default=DATA_DIR / "flashrag_datasets", - help="Local directory to save model", - ) - - return parser.parse_args() - - -def main(): - """Main function to download model.""" - args = parse_args() - load_dotenv(override=True) - - # Configuration - HF_TOKEN = os.getenv("HF_TOKEN") - - ALLOW_PATTERNS = [ - "*retrieval-corpus*", - "*bamboogle*", - "*nq*", - ] - - # Download the model - snapshot_download( - token=HF_TOKEN, - repo_id=args.repo_id, - local_dir=args.local_dir, - repo_type="dataset", - # ignore_patterns=IGNORE_PATTERNS, - allow_patterns=ALLOW_PATTERNS, - ) - - # unzip data/flashrag_datasets/retrieval-corpus/wiki18_100w.zip - print("Unzipping wiki18_100w.zip. 
Might take a while...") - zip_file_path = os.path.join(args.local_dir, "retrieval-corpus", "wiki18_100w.zip") - with zipfile.ZipFile(zip_file_path, "r") as zip_ref: - zip_ref.extractall(args.local_dir) - - print(f"✅ Done: {args.repo_id} -> {args.local_dir}") - - -if __name__ == "__main__": - main() diff --git a/scripts/serving/download_flashrag_index.py b/scripts/serving/download_flashrag_index.py deleted file mode 100644 index efc1e8e..0000000 --- a/scripts/serving/download_flashrag_index.py +++ /dev/null @@ -1,43 +0,0 @@ -"""Download and extract FlashRAG index.""" - -import os -import zipfile - -import requests -from tqdm import tqdm - -from config import DATA_DIR - -# Constants -URL = "https://www.modelscope.cn/datasets/hhjinjiajie/FlashRAG_Dataset/resolve/master/retrieval_corpus/wiki18_100w_e5_index.zip" -ZIP_NAME = "wiki18_100w_e5_index.zip" -zip_path = DATA_DIR / ZIP_NAME - -# Download with progress bar -print("📥 Downloading index...") -response = requests.get(URL, stream=True) -total_size = int(response.headers.get("content-length", 0)) - -with ( - open(zip_path, "wb") as f, - tqdm( - desc=ZIP_NAME, - total=total_size, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as bar, -): - for data in response.iter_content(chunk_size=1024): - size = f.write(data) - bar.update(size) - -# Extract -print("📦 Extracting index...") -with zipfile.ZipFile(zip_path, "r") as zip_ref: - zip_ref.extractall(DATA_DIR) - -# Clean up zip -os.remove(zip_path) -print("✅ Download and extraction completed successfully!") -print(f"Index file is at: {DATA_DIR}/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index") diff --git a/scripts/serving/download_generator_model.py b/scripts/serving/download_generator_model.py deleted file mode 100644 index 27689ac..0000000 --- a/scripts/serving/download_generator_model.py +++ /dev/null @@ -1,54 +0,0 @@ -import argparse -import os - -from dotenv import load_dotenv -from huggingface_hub import snapshot_download - -from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID - - -def parse_args() -> argparse.Namespace: - """Parse command line arguments. - - Returns: - argparse.Namespace: Parsed arguments - """ - parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub") - parser.add_argument( - "--repo-id", - type=str, - default=GENERATOR_MODEL_REPO_ID, - help="HuggingFace repository ID", - ) - parser.add_argument( - "--local-dir", - type=str, - default=GENERATOR_MODEL_DIR, - help="Local directory to save model", - ) - - return parser.parse_args() - - -def main(): - """Main function to download model.""" - args = parse_args() - load_dotenv(override=True) - - # Configuration - HF_TOKEN = os.getenv("HF_TOKEN") - - print("Downloading model to", args.local_dir) - - # Download the model - snapshot_download( - token=HF_TOKEN, - repo_id=args.repo_id, - local_dir=args.local_dir, - repo_type="model", - ) - print(f"✅ Done: {args.repo_id} -> {args.local_dir}") - - -if __name__ == "__main__": - main() diff --git a/scripts/serving/download_retriever_model.py b/scripts/serving/download_retriever_model.py deleted file mode 100644 index 28bd4db..0000000 --- a/scripts/serving/download_retriever_model.py +++ /dev/null @@ -1,54 +0,0 @@ -import argparse -import os - -from dotenv import load_dotenv -from huggingface_hub import snapshot_download - -from config import RETRIEVER_MODEL_DIR, RETRIEVER_MODEL_REPO_ID - - -def parse_args() -> argparse.Namespace: - """Parse command line arguments. 
- - Returns: - argparse.Namespace: Parsed arguments - """ - parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub") - parser.add_argument( - "--repo-id", - type=str, - default=RETRIEVER_MODEL_REPO_ID, - help="HuggingFace repository ID", - ) - parser.add_argument( - "--local-dir", - type=str, - default=RETRIEVER_MODEL_DIR, - help="Local directory to save model", - ) - - return parser.parse_args() - - -def main(): - """Main function to download model.""" - args = parse_args() - load_dotenv(override=True) - - # Configuration - HF_TOKEN = os.getenv("HF_TOKEN") - - print("Downloading model to", args.local_dir) - - # Download the model - snapshot_download( - token=HF_TOKEN, - repo_id=args.repo_id, - local_dir=args.local_dir, - repo_type="model", - ) - print(f"✅ Done: {args.repo_id} -> {args.local_dir}") - - -if __name__ == "__main__": - main() diff --git a/scripts/serving/retriever_config.yaml b/scripts/serving/retriever_config.yaml deleted file mode 100644 index fe5dd1a..0000000 --- a/scripts/serving/retriever_config.yaml +++ /dev/null @@ -1,9 +0,0 @@ -# ------------------------------------------------Environment Settings------------------------------------------------# -gpu_id: "0" - -# -------------------------------------------------Retrieval Settings------------------------------------------------# -# If set the name, the model path will be find in global paths -retrieval_method: "e5" # name or path of the retrieval model. -index_path: "/mnt/nas/thinhlpg/data/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index" # path to the indexed file -faiss_gpu: False # whether use gpu to hold index -corpus_path: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/wiki18_100w.jsonl" # path to corpus in '.jsonl' format that store the documents diff --git a/scripts/serving/serve_generator.py b/scripts/serving/serve_generator.py deleted file mode 100644 index 571a553..0000000 --- a/scripts/serving/serve_generator.py +++ /dev/null @@ -1,127 +0,0 @@ -import subprocess -import sys -from pathlib import Path - -# Add project root to sys.path to allow importing config -# Assuming the script is at DeepSearch/scripts/serving/serve_generator.py -# The project root (DeepSearch) is parents[2] -PROJ_ROOT = Path(__file__).resolve().parents[2] -if str(PROJ_ROOT) not in sys.path: - sys.path.append(str(PROJ_ROOT)) - -# Import after adjusting sys.path -try: - from config import ( - GENERATOR_MODEL_REPO_ID, - GENERATOR_SERVER_PORT, - MODEL_CONFIG, - logger, - ) -except ImportError as e: - # Use print here as logger might not be available if import failed - print( - f"Error importing config: {e}. Make sure config.py is in the project root ({PROJ_ROOT}) and added to sys.path." - ) - sys.exit(1) - - -def launch_sglang_server( - model_id: str, - port: int, - context_length: int, - host: str = "0.0.0.0", - dtype: str = "bfloat16", -) -> None: - """Launches the SGLang server using specified configurations. - - Args: - model_id: The Hugging Face repository ID of the model. - port: The port number for the server. - context_length: The maximum context length for the model. - host: The host address for the server. - dtype: The data type for the model (e.g., 'bfloat16', 'float16'). 
- """ - command = [ - sys.executable, # Use the current Python interpreter - "-m", - "sglang.launch_server", - "--model-path", - model_id, - "--context-length", - str(context_length), - "--enable-metrics", - "--dtype", - dtype, - "--host", - host, - "--port", - str(port), - "--mem-fraction-static", - "0.5", - "--trust-remote-code", - # Recommended by SGLang for stability sometimes - "--disable-overlap", - # Can sometimes cause issues - "--disable-radix-cache", - ] - - # Log the command clearly - command_str = " ".join(command) - logger.info(f"🚀 Launching SGLang server with command: {command_str}") - - process = None # Initialize process to None - try: - # Use Popen to start the server process - # It runs in the foreground relative to this script, - # but allows us to catch KeyboardInterrupt cleanly. - process = subprocess.Popen(command) - # Wait for the process to complete (e.g., user interruption) - process.wait() - # Check return code after waiting - if process.returncode != 0: - logger.error(f"💥 SGLang server process exited with error code: {process.returncode}") - sys.exit(process.returncode) - else: - logger.info("✅ SGLang server process finished gracefully.") - - except FileNotFoundError: - logger.error(f"💥 Error: Python executable or sglang module not found.") - logger.error(f"Ensure '{sys.executable}' is correct and sglang is installed.") - sys.exit(1) - except KeyboardInterrupt: - logger.info("🛑 SGLang server launch interrupted by user. Stopping server...") - # Attempt to terminate the process gracefully - if process and process.poll() is None: # Check if process exists and is running - process.terminate() - try: - process.wait(timeout=5) # Wait a bit for termination - logger.info("✅ Server terminated gracefully.") - except subprocess.TimeoutExpired: - logger.warning("⚠️ Server did not terminate gracefully, forcing kill.") - process.kill() - sys.exit(0) # Exit cleanly after interrupt - except Exception as e: - # Catch any other unexpected exceptions during launch or waiting - logger.error(f"🚨 An unexpected error occurred: {e}") - # Ensure process is cleaned up if it exists - if process and process.poll() is None: - process.kill() - sys.exit(1) - - -if __name__ == "__main__": - # Get context length from config, default to 8192 - context_len = MODEL_CONFIG.get("max_seq_length", 8192) - - logger.info("----------------------------------------------------") - logger.info("✨ Starting SGLang Generator Server ✨") - logger.info(f" Model ID: {GENERATOR_MODEL_REPO_ID}") - logger.info(f" Port: {GENERATOR_SERVER_PORT}") - logger.info(f" Context Length: {context_len}") - logger.info("----------------------------------------------------") - - launch_sglang_server( - model_id=GENERATOR_MODEL_REPO_ID, - port=GENERATOR_SERVER_PORT, - context_length=context_len, - ) diff --git a/scripts/serving/serve_retriever.py b/scripts/serving/serve_retriever.py deleted file mode 100644 index 229351c..0000000 --- a/scripts/serving/serve_retriever.py +++ /dev/null @@ -1,113 +0,0 @@ -import argparse -import asyncio -from collections import deque -from typing import List, Tuple, Union - -from fastapi import FastAPI, HTTPException -from flashrag.config import Config -from flashrag.utils import get_retriever -from pydantic import BaseModel - -from config import RETRIEVER_SERVER_PORT - -app = FastAPI() - -retriever_list = [] -available_retrievers = deque() -retriever_semaphore = None - - -def init_retriever(args): - global retriever_semaphore - config = Config(args.config) - for i in range(args.num_retriever): - 
print(f"Initializing retriever {i + 1}/{args.num_retriever}") - retriever = get_retriever(config) - retriever_list.append(retriever) - available_retrievers.append(i) - # create a semaphore to limit the number of retrievers that can be used concurrently - retriever_semaphore = asyncio.Semaphore(args.num_retriever) - - -@app.get("/health") -async def health_check(): - return {"status": "healthy", "retrievers": {"total": len(retriever_list), "available": len(available_retrievers)}} - - -class QueryRequest(BaseModel): - query: str - top_n: int = 10 - return_score: bool = False - - -class BatchQueryRequest(BaseModel): - query: List[str] - top_n: int = 10 - return_score: bool = False - - -class Document(BaseModel): - id: str - contents: str - - -@app.post("/search", response_model=Union[Tuple[List[Document], List[float]], List[Document]]) -async def search(request: QueryRequest): - query = request.query - top_n = request.top_n - return_score = request.return_score - - if not query or not query.strip(): - print(f"Query content cannot be empty: {query}") - raise HTTPException(status_code=400, detail="Query content cannot be empty") - - async with retriever_semaphore: - retriever_idx = available_retrievers.popleft() - try: - if return_score: - results, scores = retriever_list[retriever_idx].search(query, top_n, return_score) - return [Document(id=result["id"], contents=result["contents"]) for result in results], scores - else: - results = retriever_list[retriever_idx].search(query, top_n, return_score) - return [Document(id=result["id"], contents=result["contents"]) for result in results] - finally: - available_retrievers.append(retriever_idx) - - -@app.post("/batch_search", response_model=Union[List[List[Document]], Tuple[List[List[Document]], List[List[float]]]]) -async def batch_search(request: BatchQueryRequest): - query = request.query - top_n = request.top_n - return_score = request.return_score - - async with retriever_semaphore: - retriever_idx = available_retrievers.popleft() - try: - if return_score: - results, scores = retriever_list[retriever_idx].batch_search(query, top_n, return_score) - return [ - [Document(id=result["id"], contents=result["contents"]) for result in results[i]] - for i in range(len(results)) - ], scores - else: - results = retriever_list[retriever_idx].batch_search(query, top_n, return_score) - return [ - [Document(id=result["id"], contents=result["contents"]) for result in results[i]] - for i in range(len(results)) - ] - finally: - available_retrievers.append(retriever_idx) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="./retriever_config.yaml") - parser.add_argument("--num_retriever", type=int, default=1) - parser.add_argument("--port", type=int, default=RETRIEVER_SERVER_PORT) - args = parser.parse_args() - - init_retriever(args) - - import uvicorn - - uvicorn.run(app, host="0.0.0.0", port=args.port) diff --git a/scripts/train_data/build_musique_index.py b/scripts/train_data/build_musique_index.py deleted file mode 100644 index 0cb0106..0000000 --- a/scripts/train_data/build_musique_index.py +++ /dev/null @@ -1,135 +0,0 @@ -import json -import math # Import math for ceiling division -import sys -import traceback # Import traceback -from pathlib import Path - -import pandas as pd - -# Add project root to Python path if needed (adjust relative path as necessary) -project_root = Path(__file__).resolve().parent.parent -sys.path.append(str(project_root)) - -from src.embeddings import 
CustomHuggingFaceEmbeddings - -# Import FAISS after potentially adding to sys.path -try: - from langchain_community.vectorstores import FAISS -except ImportError: - print("Error: langchain_community or FAISS not installed. Please install with 'pip install langchain faiss-cpu'") - sys.exit(1) - - -def build_faiss_index_from_csv(csv_path: str, index_save_path: str, batch_size: int = 128) -> None: - """Builds a FAISS index from a CSV containing paragraph content and metadata. - - Reads a CSV file, generates embeddings for the 'content' column in batches, - and saves the FAISS index files (index.faiss, index.pkl) locally. - - Args: - csv_path: Path to the input CSV file (e.g., data/processed/paragraphs.csv). - index_save_path: Path to the directory where the index files should be saved. - batch_size: Number of texts to process in each embedding batch. - """ - print(f"Loading paragraphs from {csv_path}") - try: - df = pd.read_csv(csv_path) - except FileNotFoundError: - print(f"Error: CSV file not found at {csv_path}. Please run the extraction script first.") - return - except Exception as e: - print(f"Error reading CSV file: {e}") - return - - if "content" not in df.columns or "metadata" not in df.columns: - print("Error: CSV file must contain 'content' and 'metadata' columns.") - return - - if df.empty: - print("Warning: Input CSV file is empty. No index will be built.") - return - - # Prepare documents for FAISS - texts = df["content"].astype(str).tolist() - metadatas = [] - try: - metadatas = [json.loads(m) for m in df["metadata"].tolist()] - print(f"Prepared {len(texts)} texts and {len(metadatas)} metadatas.") - except json.JSONDecodeError as e: - print(f"Error parsing metadata JSON: {e}. Check the format in {csv_path}") - traceback.print_exc() # Print traceback for JSON errors - return - except Exception as e: - print(f"Error processing metadata: {e}") - traceback.print_exc() # Print traceback for other metadata errors - return - - if not texts or not metadatas or len(texts) != len(metadatas): - print(f"Error: Mismatch or empty texts/metadatas. 
Texts: {len(texts)}, Metadatas: {len(metadatas)}") - return - - print("Initializing embeddings model...") - try: - embeddings = CustomHuggingFaceEmbeddings() - except Exception as e: - print(f"Error initializing embeddings model: {e}") - traceback.print_exc() - return - print("Embeddings model initialized successfully.") - - vectorstore = None - num_batches = math.ceil(len(texts) / batch_size) - print(f"Processing {len(texts)} texts in {num_batches} batches of size {batch_size}...") - - for i in range(num_batches): - start_idx = i * batch_size - end_idx = min((i + 1) * batch_size, len(texts)) - batch_texts = texts[start_idx:end_idx] - batch_metadatas = metadatas[start_idx:end_idx] - print(f" Processing batch {i + 1}/{num_batches} (indices {start_idx}-{end_idx - 1})...") - - try: - if i == 0: - # Initialize the vector store with the first batch - print(f" Initializing FAISS index with first batch...") - vectorstore = FAISS.from_texts(texts=batch_texts, embedding=embeddings, metadatas=batch_metadatas) - print(" FAISS index initialized.") - else: - # Add subsequent batches to the existing store - if vectorstore is None: - print("Error: vectorstore is None after first batch, cannot add more texts.") - return # Should not happen if first batch succeeded - print(f" Adding batch {i + 1} to FAISS index...") - vectorstore.add_texts(texts=batch_texts, metadatas=batch_metadatas) - print(f" Batch {i + 1} added.") - - except Exception as e: - print(f"Error processing batch {i + 1} (indices {start_idx}-{end_idx - 1}): {e}") - traceback.print_exc() - print("Stopping index creation due to error in batch processing.") - return # Exit if any batch fails - - if vectorstore is None: - print("Error: Failed to create or add any data to the vectorstore.") - return - - # Save the completed index - try: - print(f"Attempting to save final FAISS index files to directory: {index_save_path}") - # Ensure the target directory exists before saving - Path(index_save_path).mkdir(parents=True, exist_ok=True) - vectorstore.save_local(index_save_path) - print(f"Successfully saved final FAISS index files (index.faiss, index.pkl) to: {index_save_path}") - except Exception as e: - print(f"Error during final vectorstore.save_local to {index_save_path}: {e}") - traceback.print_exc() - - -if __name__ == "__main__": - # Define paths relative to this script or use absolute paths - PROCESSED_DIR = Path("data/processed") - INPUT_CSV = str(PROCESSED_DIR / "paragraphs.csv") - # FAISS save_local will save index.faiss and index.pkl in this directory - INDEX_SAVE_DIR = str(PROCESSED_DIR) # Save directly to processed dir - - build_faiss_index_from_csv(INPUT_CSV, INDEX_SAVE_DIR, batch_size=128) diff --git a/scripts/train_data/download_data_musique.sh b/scripts/train_data/download_data_musique.sh deleted file mode 100644 index 9860a68..0000000 --- a/scripts/train_data/download_data_musique.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash -# This script is taken from https://github.com/StonyBrookNLP/musique with slight modifications - -set -e -set -x - -# If gdown doesn't work, you can download files from mentioned URLs manually -# and put them at appropriate locations. 
-pip install gdown - -ZIP_NAME="musique_v1.0.zip" - -# URL: https://drive.google.com/file/d/1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h/view?usp=sharing -gdown --id 1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h --output $ZIP_NAME - -TARGET_DIR="./data/raw" -mkdir -p $TARGET_DIR -unzip -o $(basename $ZIP_NAME) -d $TARGET_DIR # Extract directly into target - -# Move contents from the extracted 'data' folder up one level -mv $TARGET_DIR/data/* $TARGET_DIR/ - -# Clean up the empty directory and the zip -rm -rf $TARGET_DIR/data -rm $ZIP_NAME - -# TODO: prevent these from zipping in. -rm -rf __MACOSX -# Clean up potential extracted .DS_Store -rm -f $TARGET_DIR/.DS_Store diff --git a/scripts/train_data/extract_musique_paragraphs.py b/scripts/train_data/extract_musique_paragraphs.py deleted file mode 100644 index 8b4ef7c..0000000 --- a/scripts/train_data/extract_musique_paragraphs.py +++ /dev/null @@ -1,101 +0,0 @@ -import json -import sys -from collections import defaultdict # Use defaultdict for cleaner accumulation -from pathlib import Path - -import pandas as pd - -# Add project root to Python path if needed (adjust relative path as necessary) -# project_root = Path(__file__).resolve().parent.parent -# sys.path.append(str(project_root)) -# from config import logger # Assuming you have a logger setup - - -def extract_unique_paragraphs(input_paths: list[str], output_csv_path: str) -> None: - """Extracts unique paragraphs from specified JSONL files. - - Reads Musique JSONL files (train, dev, test), finds unique paragraphs - (regardless of is_supporting flag), combines title and text, - tracks source question IDs, and saves to CSV. - - Args: - input_paths: A list of paths to the input JSONL files. - output_csv_path: Path to save the output CSV file. - """ - output_dir = Path(output_csv_path).parent - output_dir.mkdir(parents=True, exist_ok=True) - - # Use paragraph content as key, value is the set of source question IDs - paragraphs_data = defaultdict(set) - print("Starting paragraph extraction (including non-supporting)...") - - for file_path in input_paths: - print(f"Processing file: {file_path}") - try: - with open(file_path, "r", encoding="utf-8") as infile: - for line_num, line in enumerate(infile, 1): - try: - data = json.loads(line) - main_question_id = data.get("id") - if not main_question_id: - print(f"Warning: Missing 'id' in line {line_num} of {file_path}") - continue - - for p in data.get("paragraphs", []): - title = p.get("title", "No Title") - text = p.get("paragraph_text", "") - content = f"{title}\n{text}".strip() - - if not content: - continue # Skip empty paragraphs - - paragraphs_data[content].add(main_question_id) - - except json.JSONDecodeError: - print(f"Warning: Skipping invalid JSON in line {line_num} of {file_path}") - except Exception as e: - print(f"Warning: Error processing line {line_num} in {file_path}: {e}") - except FileNotFoundError: - print(f"Error: Input file not found: {file_path}") - except Exception as e: - print(f"Error reading file {file_path}: {e}") - - print(f"Found {len(paragraphs_data)} unique paragraphs (supporting and non-supporting).") - - # Prepare data for DataFrame - output_list = [] - sorted_content = sorted(paragraphs_data.keys()) - for chunk_id, content in enumerate(sorted_content, 1): - question_ids = paragraphs_data[content] - metadata = {"source_question_ids": sorted(list(question_ids))} - output_list.append( - { - "chunk_id": chunk_id, - "content": content, - "metadata": json.dumps(metadata), # Store metadata as JSON string - } - ) - - if not output_list: - 
print("No paragraphs found to save.") - return - df = pd.DataFrame(output_list) - try: - df.to_csv(output_csv_path, index=False) - print(f"Successfully saved unique paragraphs to {output_csv_path}") - except Exception as e: - print(f"Error saving CSV file: {e}") - - -if __name__ == "__main__": - RAW_DIR = Path("data/raw") - PROCESSED_DIR = Path("data/processed") - - input_files = [ - str(RAW_DIR / "musique_ans_v1.0_train.jsonl"), - str(RAW_DIR / "musique_ans_v1.0_dev.jsonl"), - str(RAW_DIR / "musique_ans_v1.0_test.jsonl"), - ] - output_csv = str(PROCESSED_DIR / "paragraphs.csv") - - extract_unique_paragraphs(input_files, output_csv) diff --git a/scripts/train_data/prepare_musique_dev_jsonl.py b/scripts/train_data/prepare_musique_dev_jsonl.py deleted file mode 100644 index dc19d0c..0000000 --- a/scripts/train_data/prepare_musique_dev_jsonl.py +++ /dev/null @@ -1,155 +0,0 @@ -"""Prepares a deterministic sampled dev set (questions_dev.jsonl) from raw Musique dev data.""" - -import json -import math -import os -import re -from collections import defaultdict -from pathlib import Path - - -def transform_musique_dev_data(input_path: str, output_path: str, sample_config: dict) -> None: - """Transforms Musique dev data with deterministic stratified sampling using uniform selection from sorted lists. - - Reads dev data, categorizes by hop type (2, 3, 4), sorts categories by ID, - selects N samples uniformly spaced from each sorted category based on sample_config, - combines, sorts final list by ID, combines answers/aliases, extracts supporting paras, - and writes the transformed data to output_path. - - Args: - input_path: Path to the input JSONL file (e.g., data/raw/musique_ans_v1.0_dev.jsonl). - output_path: Path to the output JSONL file (e.g., data/processed/questions_dev.jsonl). - sample_config: Dictionary specifying samples per hop type (e.g., {"2hop": 20, "3hop": 15, "4hop": 15}). 
- """ - output_dir = Path(output_path).parent - output_dir.mkdir(parents=True, exist_ok=True) - - print(f"Reading all data from {input_path} for dev sampling...") - all_data = [] - try: - with open(input_path, "r", encoding="utf-8") as infile: - for line_num, line in enumerate(infile, 1): - try: - data = json.loads(line) - if "id" in data: - all_data.append(data) - else: - print(f"Warning: Skipping line {line_num} due to missing 'id' field in {input_path}") - except json.JSONDecodeError: - print(f"Warning: Skipping invalid JSON in line {line_num} of {input_path}") - except FileNotFoundError: - print(f"Error: Input file not found at {input_path}") - return - except Exception as e: - print(f"Error reading file {input_path}: {e}") - return - print(f"Read {len(all_data)} total samples from dev set.") - - # Categorize data by hop count (2hop, 3hop, 4hop) - categorized_data = defaultdict(list) - print("Categorizing data by hop type (2, 3, 4)...") - for data in all_data: - q_id = data["id"] - hop_type = None - if q_id.startswith("2hop"): - hop_type = "2hop" - elif q_id.startswith("3hop"): - hop_type = "3hop" - elif q_id.startswith("4hop"): - hop_type = "4hop" - - if hop_type: - categorized_data[hop_type].append(data) - - # Deterministic sampling using sorting and uniform index selection - final_sample_list = [] - total_target = sum(sample_config.values()) - print(f"Sampling deterministically via uniform selection from sorted lists to get {total_target} dev samples...") - - for hop_type, target_count in sample_config.items(): - available_samples = categorized_data.get(hop_type, []) - current_count = len(available_samples) - print(f" {hop_type}: Found {current_count} samples, need {target_count}.") - - if current_count == 0: - continue - - available_samples.sort(key=lambda x: x["id"]) - selected_samples_for_hop = [] - if current_count < target_count: - print(f" Warning: Not enough samples for {hop_type}. 
Taking all {current_count} sorted samples.") - selected_samples_for_hop = available_samples - elif target_count > 0: # Ensure target_count is positive before selecting - print(f" Selecting {target_count} samples uniformly from {current_count}...") - # Calculate indices using integer interpretation of evenly spaced points - indices_to_take = [ - int(i * (current_count - 1) / (target_count - 1)) if target_count > 1 else 0 - for i in range(target_count) - ] # Adjust index calc for edges - indices_to_take = sorted(list(set(indices_to_take))) # Ensure unique indices - # Simple fallback if uniqueness reduced count below target - while len(indices_to_take) < target_count and len(indices_to_take) < current_count: - next_val = indices_to_take[-1] + 1 - if next_val < current_count: - indices_to_take.append(next_val) - else: # Cannot add more unique indices - break - selected_samples_for_hop = [ - available_samples[idx] for idx in indices_to_take[:target_count] - ] # Select based on unique indices, capped at target - - final_sample_list.extend(selected_samples_for_hop) - - print(f"Selected {len(final_sample_list)} dev samples in total.") - - # Sort the final combined list by ID for consistent output order - print("Sorting the final combined dev sample list by ID...") - final_sample_list.sort(key=lambda x: x["id"]) - - # Process and write the selected samples - print(f"Processing and writing {len(final_sample_list)} selected dev samples to {output_path}...") - count = 0 - try: - with open(output_path, "w", encoding="utf-8") as outfile: - for data in final_sample_list: - try: - supporting_paragraphs = [ - p["paragraph_text"] for p in data.get("paragraphs", []) if p.get("is_supporting", False) - ] - main_answer = data.get("answer", "") - aliases = data.get("answer_aliases", []) - all_answers = [main_answer] + (aliases if isinstance(aliases, list) else []) - valid_answers = [str(ans).strip() for ans in all_answers if ans and str(ans).strip()] - unique_valid_answers = list(set(valid_answers)) # Keep unique, don't sort alphabetically - combined_answer_str = " OR ".join(unique_valid_answers) - - output_data = { - "id": data.get("id"), - "question": data.get("question"), - "answer": combined_answer_str, - "supporting_paragraphs": supporting_paragraphs, - } - outfile.write(json.dumps(output_data) + "\n") - count += 1 - except KeyError as e: - print(f"Skipping sample due to missing key {e}: {data.get('id')}") - print(f"Successfully processed and wrote {count} dev samples.") - except Exception as e: - print(f"An unexpected error occurred during writing: {e}") - - -if __name__ == "__main__": - # Define file paths relative to the project root - # Ensure this script is run from the project root or adjust paths accordingly - RAW_DIR = Path("data/raw") - PROCESSED_DIR = Path("data/processed") - - # Define sampling configuration for the dev set - DEV_SAMPLING_CONFIG = {"2hop": 20, "3hop": 15, "4hop": 15} # Total = 50 - - INPUT_FILE = RAW_DIR / "musique_ans_v1.0_dev.jsonl" - OUTPUT_FILE = PROCESSED_DIR / "questions_dev.jsonl" - - transform_musique_dev_data(str(INPUT_FILE), str(OUTPUT_FILE), DEV_SAMPLING_CONFIG) - - print(f"\nMusique DEV JSONL transformation and deterministic sampling complete.") diff --git a/scripts/train_data/prepare_musique_jsonl.py b/scripts/train_data/prepare_musique_jsonl.py deleted file mode 100644 index 74e41da..0000000 --- a/scripts/train_data/prepare_musique_jsonl.py +++ /dev/null @@ -1,172 +0,0 @@ -import json -import math # Keep math import -import os -import re # Import re for parsing ID 
-from collections import defaultdict -from pathlib import Path - -# import random # No longer needed -# SEED = 42 # No longer needed -# random.seed(SEED) # No longer needed - - -def transform_musique_data(input_path: str, output_path: str, sample_config: dict) -> None: - """Transforms Musique data with deterministic stratified sampling using uniform selection from sorted lists. - - Reads data, categorizes by detailed hop type, sorts categories by ID, - selects N samples uniformly spaced from each sorted category, - combines, sorts final list by ID, and writes to output. - - Args: - input_path: Path to the input JSONL file. - output_path: Path to the output JSONL file. - sample_config: Dictionary specifying samples per detailed hop type (e.g., {"2hop": 400, "3hop1": 150, ...}). - """ - output_dir = Path(output_path).parent - output_dir.mkdir(parents=True, exist_ok=True) - - print(f"Reading all data from {input_path} for sampling...") - all_data = [] - try: - with open(input_path, "r", encoding="utf-8") as infile: - for line_num, line in enumerate(infile, 1): - try: - data = json.loads(line) - if "id" in data: - all_data.append(data) - else: - print(f"Warning: Skipping line {line_num} due to missing 'id' field in {input_path}") - except json.JSONDecodeError: - print(f"Warning: Skipping invalid JSON in line {line_num} of {input_path}") - except FileNotFoundError: - print(f"Error: Input file not found at {input_path}") - return - except Exception as e: - print(f"Error reading file {input_path}: {e}") - return - print(f"Read {len(all_data)} total samples with IDs.") - - # Detailed Categorization by hop type - categorized_data = defaultdict(list) - print("Categorizing data by detailed hop type (e.g., 3hop1, 4hop2)...") - for data in all_data: - q_id = data["id"] - match = re.match(r"^(2hop|3hop[12]|4hop[123])__", q_id) - if match: - detailed_hop_type = match.group(1) - categorized_data[detailed_hop_type].append(data) - # else: # Optional: log if an ID doesn't match expected pattern - # print(f"Warning: ID {q_id} does not match expected hop pattern.") - - # Deterministic sampling using sorting and uniform index selection - final_sample_list = [] - total_target = sum(sample_config.values()) - print(f"Sampling deterministically via uniform selection from sorted lists to get {total_target} samples...") - # Check if all requested hop types exist in config - for hop_type in sample_config.keys(): - if hop_type not in categorized_data: - print(f"Warning: Hop type '{hop_type}' requested in config but not found in data.") - - for hop_type, target_count in sample_config.items(): - available_samples = categorized_data.get(hop_type, []) - current_count = len(available_samples) - print(f" {hop_type}: Found {current_count} samples, need {target_count}.") - - if current_count == 0: - continue - - # Sort the list for this category by ID - available_samples.sort(key=lambda x: x["id"]) - - selected_samples_for_hop = [] - if current_count < target_count: - print(f" Warning: Not enough samples for {hop_type}. 
Taking all {current_count} sorted samples.") - selected_samples_for_hop = available_samples - else: - # Select target_count indices spread uniformly across the available samples - print(f" Selecting {target_count} samples uniformly from {current_count}...") - # Calculate indices using integer interpretation of evenly spaced points - indices_to_take = [int(i * current_count / target_count) for i in range(target_count)] - # Ensure uniqueness in case of rounding issues with small numbers (though unlikely here) - indices_to_take = sorted(list(set(indices_to_take))) - # Adjust if rounding resulted in fewer than target_count unique indices - while len(indices_to_take) < target_count: - # This is a fallback, shouldn't happen if current_count >= target_count - # Add indices from the end if needed, avoiding duplicates - next_idx = indices_to_take[-1] + 1 - if next_idx < current_count and next_idx not in indices_to_take: - indices_to_take.append(next_idx) - else: # Should not be reachable if logic is sound - break - - # Select samples at the calculated indices - selected_samples_for_hop = [ - available_samples[idx] for idx in indices_to_take[:target_count] - ] # Ensure we take exactly target_count - - final_sample_list.extend(selected_samples_for_hop) - - print(f"Selected {len(final_sample_list)} samples in total.") - - # Sort the final combined list by ID for consistent output order - print("Sorting the final combined sample list by ID...") - final_sample_list.sort(key=lambda x: x["id"]) - - # Process and write the selected samples - print(f"Processing and writing {len(final_sample_list)} selected samples to {output_path}...") - count = 0 - try: - with open(output_path, "w", encoding="utf-8") as outfile: - for data in final_sample_list: - try: - supporting_paragraphs = [ - p["paragraph_text"] for p in data.get("paragraphs", []) if p.get("is_supporting", False) - ] - - main_answer = data.get("answer", "") - aliases = data.get("answer_aliases", []) - - all_answers = [main_answer] + (aliases if isinstance(aliases, list) else []) - valid_answers = [str(ans).strip() for ans in all_answers if ans and str(ans).strip()] - unique_valid_answers = list(set(valid_answers)) - - combined_answer_str = " OR ".join(unique_valid_answers) - - output_data = { - "id": data.get("id"), - "question": data.get("question"), - "answer": combined_answer_str, - "supporting_paragraphs": supporting_paragraphs, - } - outfile.write(json.dumps(output_data) + "\n") - count += 1 - except KeyError as e: - print(f"Skipping sample due to missing key {e}: {data.get('id')}") - print(f"Successfully processed and wrote {count} samples.") - except Exception as e: - print(f"An unexpected error occurred during writing: {e}") - - -if __name__ == "__main__": - # Define file paths - RAW_DIR = Path("data/raw") - PROCESSED_DIR = Path("data/processed") - - # Define detailed sampling configuration - SAMPLING_CONFIG = { - "2hop": 400, - "3hop1": 150, - "3hop2": 150, - "4hop1": 100, - "4hop2": 100, - "4hop3": 100, - } # Total = 1000 - - transform_musique_data( - str(RAW_DIR / "musique_ans_v1.0_train.jsonl"), str(PROCESSED_DIR / "questions.jsonl"), SAMPLING_CONFIG - ) - - print( - "\nMusique JSONL transformation and detailed deterministic sampling (uniform selection from sorted) complete." - ) - # Note: Dev/Test files are not processed by default with this sampling logic. 
diff --git a/third_party/FlashRAG b/third_party/FlashRAG
deleted file mode 160000
index 7e60ab2..0000000
--- a/third_party/FlashRAG
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 7e60ab26825a452f8ee8eb19799d1cb6c1746326