chore: remove unused code and dependencies

main
thinhlpg 4 weeks ago
parent 5eabd121a3
commit 89e07bc02d

@@ -1,2 +1,3 @@
HF_TOKEN=<your-huggingface-token>
OPENROUTER_API_KEY=<your-openrouter-api-key>
TAVILY_API_KEY=<your-tavily-api-key>
SERPER_API_KEY=<your-serper-api-key>

1
.gitignore vendored

@@ -19,6 +19,7 @@ logs/
data/
.gradio/
output*
llama.cpp*
# Byte-compiled / optimized / DLL files
__pycache__/

3
.gitmodules vendored

@@ -1,3 +0,0 @@
[submodule "third_party/FlashRAG"]
path = third_party/FlashRAG
url = https://github.com/RUC-NLPIR/FlashRAG.git

@@ -1,4 +1,4 @@
.PHONY: style quality install tensorboard clean fix update-worklog test data download-musique prepare-musique-jsonl extract-musique-paragraphs build-musique-index prepare-musique-index prepare-all-musique check-data prepare-dev-data ensure-unzip download-all-models serve-retriever serve-generator run-evaluation download-flashrag-data download-flashrag-index download-retriever-model download-generator-model serve-all run-full-evaluation evaluation-download-models prepare-serving serve-background stop-serving
.PHONY: style quality install clean fix test check-data simple-qa generate-data eval-base eval-lora download-checkpoint upload-checkpoint save-merged-16bit
# make sure to test the local checkout in scripts and not the pre-installed one
export PYTHONPATH = src
@@ -11,7 +11,7 @@ test:
# Development dependencies
install:
pip install -e . && pip install -e third_party/FlashRAG
pip install -e .
# Code quality and style
style:
@@ -28,159 +28,38 @@ fix:
ruff check --fix --line-length 119 --target-version py311 $(check_dirs)
isort $(check_dirs)
# TensorBoard
tensorboard:
tensorboard --logdir=trainer_output_*_runs --port=6006
# List available run directories
list-runs:
@echo "Available run directories:"
@ls -d trainer_output_*_runs 2>/dev/null || echo "No run directories found"
# Ensure unzip is available
ensure-unzip:
@which unzip > /dev/null || (echo "Installing unzip..." && sudo apt-get update && sudo apt-get install -y unzip)
@echo "✓ unzip is available"
# Data Preparation - One command to rule them all
data: download-musique prepare-musique-jsonl extract-musique-paragraphs build-musique-index prepare-dev-data check-data
@echo "✨ All data preparation complete! ✨"
# Index Preparation
prepare-musique-index: build-musique-index
@echo "Musique index preparation complete."
# Check Data
check-data:
@echo "Checking generated data files..."
python scripts/check_data.py
download-musique: ensure-unzip
@echo "Downloading Musique dataset..."
bash scripts/train_data/download_data_musique.sh
@echo "Musique dataset ready in ./data/raw/"
# Simple QA
simple-qa:
python scripts/simple_qa.py
prepare-musique-jsonl: download-musique
@echo "Preparing Musique data (JSONL)..."
python scripts/train_data/prepare_musique_jsonl.py
@echo "Processed Musique JSONL ready in ./data/processed/questions.jsonl"
# Generate data
generate-data:
python scripts/generate_data.py
extract-musique-paragraphs: download-musique
@echo "Extracting unique paragraphs from raw Musique data..."
python scripts/train_data/extract_musique_paragraphs.py
@echo "Musique paragraphs extracted to ./data/processed/paragraphs.csv"
# Evaluate base model
eval-base:
python scripts/eval_base.py
build-musique-index: extract-musique-paragraphs
@echo "Building Musique FAISS index from paragraphs..."
python scripts/train_data/build_musique_index.py
@echo "Musique FAISS index files saved to ./data/processed/"
# Evaluate LoRA model
eval-lora:
python scripts/eval_lora.py
# Combined Preparation
prepare-all-musique: data prepare-musique-index
@echo "All Musique data and index preparation complete."
# Download checkpoint
download-checkpoint:
python scripts/download_checkpoint.py
# Check Data
check-data:
@echo "Checking generated data files..."
python scripts/check_data.py
# Upload checkpoint
upload-checkpoint:
python scripts/upload_checkpoint.py
# Prepare Dev Data
prepare-dev-data: download-musique
@echo "Preparing Musique DEV data (JSONL)..."
python scripts/train_data/prepare_musique_dev_jsonl.py
@echo "Processed Musique DEV JSONL ready in ./data/processed/questions_dev.jsonl"
# ======= SERVING COMMANDS =======
# Prepare everything needed for serving (download models and data)
prepare-serving: download-all-models
@echo "✨ All models and data for serving prepared! ✨"
@echo "You can now run services with:"
@echo " make serve-retriever"
@echo " make serve-generator"
@echo " or both with separate terminals"
# Download all required models and data for serving
download-all-models: download-flashrag-data download-flashrag-index download-retriever-model download-generator-model
@echo "✨ All models and data downloaded! ✨"
# Download FlashRAG datasets
download-flashrag-data:
@echo "Downloading FlashRAG datasets..."
python scripts/serving/download_flashrag_datasets.py
@echo "FlashRAG datasets downloaded!"
# Download FlashRAG index
download-flashrag-index:
@echo "Downloading FlashRAG index..."
python scripts/serving/download_flashrag_index.py
@echo "FlashRAG index downloaded!"
# Download retriever model
download-retriever-model:
@echo "Downloading retriever model..."
python scripts/serving/download_retriever_model.py
@echo "Retriever model downloaded!"
# Download generator model
download-generator-model:
@echo "Downloading generator model..."
python scripts/serving/download_generator_model.py
@echo "Generator model downloaded!"
# Serve retriever
serve-retriever: download-retriever-model download-flashrag-index download-flashrag-data
@echo "Starting retriever service..."
python scripts/serving/serve_retriever.py --config scripts/serving/retriever_config.yaml
# Serve generator
serve-generator: download-generator-model
@echo "Starting generator service..."
python scripts/serving/serve_generator.py
# Start both services (retriever and generator) in the background
serve-background: prepare-serving
@echo "Starting both retriever and generator services in background..."
@mkdir -p logs
@echo "Starting retriever in background..."
@nohup python scripts/serving/serve_retriever.py --config scripts/serving/retriever_config.yaml > logs/retriever.log 2>&1 &
@echo "Retriever started! PID: $$!"
@echo "Starting generator in background..."
@nohup python scripts/serving/serve_generator.py > logs/generator.log 2>&1 &
@echo "Generator started! PID: $$!"
@echo "✨ Both services running in background! ✨"
@echo "Check logs in logs/retriever.log and logs/generator.log"
@echo "To stop services: make stop-serving"
# Stop all serving processes
stop-serving:
@echo "Stopping all serving processes..."
@pkill -f 'python scripts/serving/serve_' || echo "No serving processes found"
@echo "✅ All services stopped!"
# Serve all components
serve-all: download-all-models
@echo "Starting all services..."
@echo "Please run these commands in separate terminals:"
@echo " make serve-retriever"
@echo " make serve-generator"
@echo ""
@echo "Or run both in background with one command:"
@echo " make serve-background"
@echo ""
@echo "To stop background services:"
@echo " make stop-serving"
# ======= EVALUATION COMMANDS =======
# Download models needed for evaluation
evaluation-download-models: download-all-models
@echo "✨ All models for evaluation downloaded! ✨"
# Run evaluation script
run-evaluation:
@echo "Running evaluation..."
python scripts/evaluation/run_eval.py --config scripts/evaluation/eval_config.yaml
@echo "Evaluation complete! Results in scripts/evaluation/output_logs/"
# Run complete evaluation pipeline
run-full-evaluation: evaluation-download-models run-evaluation
@echo "✨ Full evaluation pipeline complete! ✨"
# Save merged 16bit model
save-merged-16bit:
python scripts/save_merged_16bit.py
# Clean up
clean:
@@ -196,12 +75,4 @@ clean:
find . -type d -name ".coverage" -exec rm -r {} +
find . -type d -name "htmlcov" -exec rm -r {} +
find . -type d -name "build" -exec rm -r {} +
find . -type d -name "dist" -exec rm -r {} +
rm -rf ./data/raw ./data/processed # Clean raw and processed data
# Clean up the old faiss_index directory if it exists
rm -rf ./data/processed/faiss_index
# Update worklog in GitHub issue
update-worklog:
gh api -X PATCH /repos/menloresearch/DeepSearch/issues/comments/2743047160 \
-f body="$$(cat docs/00_worklog.md)" | cat && kill -9 $$PPID
find . -type d -name "dist" -exec rm -r {} +

@@ -1,46 +0,0 @@
# ------------------------------------------------Environment Settings------------------------------------------------#
# Directory paths for data and outputs
data_dir: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/"
save_dir: "/mnt/nas/thinhlpg/code/DeepSearch/logs"
# Seed for reproducibility
seed: 2024
# Whether to save intermediate data
save_intermediate_data: True
save_note: 'experiment'
# -------------------------------------------------Retrieval Settings------------------------------------------------#
# If the remote url is set, the retriever will be a remote retriever and the following settings are ignored
use_remote_retriever: True
remote_retriever_url: "localhost:8001"
instruction: ~ # instruction for retrieval model
retrieval_topk: 5 # number of retrieved documents
retrieval_batch_size: 256 # batch size for retrieval
retrieval_use_fp16: True # whether to use fp16 for retrieval model
retrieval_query_max_length: 128 # max length of the query
save_retrieval_cache: False # whether to save the retrieval cache
use_retrieval_cache: False # whether to use the retrieval cache
retrieval_cache_path: ~ # path to the retrieval cache
retrieval_pooling_method: ~ # set automatically if not provided
# -------------------------------------------------Generator Settings------------------------------------------------#
framework: sgl_remote # inference framework of the LLM; supported: 'hf', 'vllm', 'fschat'
sgl_remote_url: "localhost:8002"
generator_model: "janhq/250404-llama-3.2-3b-instruct-grpo-03-s250" # name or path of the generator model, used for loading the tokenizer
generator_max_input_len: 2048 # max length of the input
generation_params:
do_sample: False
max_tokens: 8192
# -------------------------------------------------Evaluation Settings------------------------------------------------#
# Metrics to evaluate the result
metrics: ['em', 'f1', 'acc', 'precision', 'recall']
# Specify setting for metric, will be called within certain metrics
metric_setting:
retrieval_recall_topk: 5
save_metric_score: True # whether to save the metric score into txt file
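
The evaluation entry point below consumes this YAML through flashrag's Config, layering CLI overrides on top; a minimal sketch of that load step, assuming the file is saved as scripts/evaluation/eval_config.yaml as referenced in the Makefile:

from flashrag.config import Config

overrides = {"retrieval_topk": 5}  # example CLI-style override, mirroring run_eval.py below
config = Config("scripts/evaluation/eval_config.yaml", config_dict=overrides)
print(config.remote_retriever_url, config.sgl_remote_url)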

@@ -1,656 +0,0 @@
import argparse
import json
import os
import re
import time
from copy import deepcopy
from datetime import datetime
from functools import wraps
import numpy as np
import requests
from flashrag.config import Config
from flashrag.generator.generator import BaseGenerator
from flashrag.pipeline import BasicPipeline
from flashrag.retriever.retriever import BaseTextRetriever
from flashrag.utils import get_dataset
from transformers import AutoTokenizer
from config import logger
from src.agent import Agent, AgenticOutputs
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
def retry(max_retries=10, sleep=1):
"""Decorator to retry a function with exponential backoff."""
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
func_name = func.__name__
for attempt in range(max_retries):
try:
result = func(*args, **kwargs)
return result
except Exception as e:
logger.warning(f"Attempt {attempt + 1} of {func_name} failed: {e}")
if attempt == max_retries - 1:
logger.error(f"Function {func_name} failed after {max_retries} retries.", exc_info=True)
raise e
backoff_time = sleep * (2**attempt)
logger.info(f"Retrying {func_name} in {backoff_time:.2f} seconds...")
time.sleep(backoff_time)
logger.error(f"Function {func_name} retry logic finished unexpectedly.")
return None
return wrapper
return decorator
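# Usage sketch for the retry decorator (the helper and URL below are illustrative
# only): failures are retried with exponential backoff, sleeping 1s, then 2s,
# before the final attempt re-raises the exception.
@retry(max_retries=3, sleep=1)
def _ping_service(url: str = "http://localhost:8001/health") -> int:
    response = requests.get(url, timeout=5)
    response.raise_for_status()
    return response.status_code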
class RemoteRetriever(BaseTextRetriever):
"""A wrapper for remote retriever service with retry logic and logging."""
def __init__(self, config: Config):
"""Initializes the RemoteRetriever."""
super().__init__(config)
self.remote_url = f"http://{getattr(config, 'remote_retriever_url', 'localhost:8001')}"
self.topk = getattr(config, "retriever_topk", 5)
logger.info(f"🔗 Remote retriever URL: {self.remote_url}")
@retry(max_retries=3, sleep=2)
def _search(self, query: str, num: int | None = None, return_score: bool = False) -> list[dict]:
"""Search for documents using the remote retriever service."""
num = num if num is not None else self.topk
url = f"{self.remote_url}/search"
try:
response = requests.post(
url,
json={"query": query, "top_n": num, "return_score": return_score},
timeout=30,
)
response.raise_for_status()
results = response.json()
return results
except requests.exceptions.Timeout:
logger.error(f"Search request timed out after 30 seconds for query: {query[:50]}...")
raise
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to search service at {url}")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Search request failed: {e}", exc_info=True)
raise
except Exception as e:
logger.error(f"Unexpected search error: {str(e)}", exc_info=True)
raise
@retry(max_retries=3, sleep=2)
def _batch_search(
self, queries: list[str], num: int | None = None, return_score: bool = False
) -> list[list[dict]]:
"""Batch search for documents using the remote retriever service."""
num = num if num is not None else self.topk
url = f"{self.remote_url}/batch_search"
try:
response = requests.post(
url,
json={"query": queries, "top_n": num, "return_score": return_score},
timeout=60,
)
response.raise_for_status()
results = response.json()
return results
except requests.exceptions.Timeout:
logger.error(f"Batch search request timed out after 60 seconds for {len(queries)} queries.")
raise
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to batch search service at {url}")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Batch search request failed: {e}", exc_info=True)
raise
except Exception as e:
logger.error(f"Unexpected batch search error: {str(e)}", exc_info=True)
raise
class ReSearchPipeline(BasicPipeline):
"""Pipeline for ReSearch method using Agent for generation and tool use."""
def __init__(
self, config: Config, retriever: BaseTextRetriever | None = None, generator: BaseGenerator | None = None
):
"""Initializes the ReSearchPipeline."""
super().__init__(config)
logger.info("🔧 Initializing ReSearchPipeline...")
self.retriever = retriever or RemoteRetriever(config)
self.generator = generator or SGLRemoteGenerator(config)
try:
self.tokenizer = AutoTokenizer.from_pretrained(config.generator_model_path, trust_remote_code=True)
if not self.tokenizer.pad_token:
logger.warning("Tokenizer does not have a pad token; setting to eos_token.")
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side = "left"
logger.info("✅ Tokenizer initialized.")
except Exception as e:
logger.error(f"Failed to initialize tokenizer: {e}", exc_info=True)
raise
tokenizer_name = self.tokenizer.name_or_path.lower()
if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
adapter = R1DistilTokenizerAdapter()
elif "llama" in tokenizer_name:
adapter = LlamaTokenizerAdapter()
else:
logger.warning(f"Unknown tokenizer type '{tokenizer_name}', defaulting to R1DistilTokenizerAdapter.")
adapter = R1DistilTokenizerAdapter()
logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")
def retriever_search(query: str, return_type=str, results: int = 5):
try:
search_results = self.retriever._search(query, num=results)
return self.format_search_results(search_results)
except Exception as e:
logger.error(f"Error during agent's retriever search for query '{query[:50]}...': {e}", exc_info=True)
return "<information>Search failed due to an internal error."
self.agent = Agent(adapter, search_fn=retriever_search)
logger.info("✅ Agent initialized.")
logger.info("✅ ReSearchPipeline initialized successfully.")
def format_search_results(self, search_results: list[dict]) -> str:
"""Formats search results into a string for the agent prompt."""
if not search_results:
return "<information>No results found.</information>"
max_content_len = 500
formatted = "\n-------\n".join(
[
f"Result {i + 1}: {r.get('contents', 'N/A')[:max_content_len]}{'...' if len(r.get('contents', '')) > max_content_len else ''}"
for i, r in enumerate(search_results)
]
)
formatted_str = f"<information>{formatted}</information>"
return formatted_str
def extract_search_query(self, text: str) -> str | None:
"""Extract search query from text between <search> tags."""
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
matches = pattern.findall(text)
if matches:
query = matches[-1].strip()
return query
return None
def extract_answer(self, text: str) -> str | None:
"""Extract answer from text between <answer> tags."""
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
matches = pattern.findall(text)
if matches:
answer = matches[-1].strip()
return answer
return None
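# Example: for a response ending in "<answer>Paris</answer>", extract_answer returns
# "Paris"; extract_search_query behaves the same way for the last <search>...</search> span.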
def run(self, dataset, do_eval: bool = True, pred_process_fun=None):
"""Runs the ReSearch pipeline on the dataset using the Agent."""
logger.info(f"🏃 Starting ReSearch pipeline run with {len(dataset)} items...")
try:
questions = [item.question if hasattr(item, "question") else item["question"] for item in dataset]
except (KeyError, AttributeError, TypeError) as e:
logger.error(f"Failed to extract questions from dataset items. Error: {e}", exc_info=True)
logger.error("Ensure dataset items have a 'question' key or attribute.")
return dataset
agent_max_generations = getattr(self.config, "agent_max_generations", 32)
generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576)
try:
logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
agent_outputs: AgenticOutputs = self.agent.run_agent(
generate_fn=self.generator.generate,
tokenizer=self.tokenizer,
questions=questions,
max_generations=agent_max_generations,
max_new_tokens=generator_max_output_len,
)
final_responses = agent_outputs.final_response_str
logger.info(f"✅ Agent inference completed. Received {len(final_responses)} final responses.")
except Exception as e:
logger.error(f"Agent run failed during inference: {e}", exc_info=True)
logger.warning("Agent run failed, attempting evaluation with potentially incomplete results.")
for item in dataset:
if hasattr(item, "update_output"):
item.update_output("pred", "AGENT_ERROR")
elif isinstance(item, dict):
item["pred"] = "AGENT_ERROR"
logger.info("📝 Extracting answers and updating dataset items...")
num_updated = 0
num_missing_answers = 0
if len(final_responses) == len(dataset):
for i, item in enumerate(dataset):
response = final_responses[i]
answer = self.extract_answer(response)
pred_to_save = answer if answer is not None else ""
if answer is None:
num_missing_answers += 1
if hasattr(item, "update_output"):
item.update_output("pred", pred_to_save)
item.update_output("final_response", response)
num_updated += 1
elif isinstance(item, dict):
item["pred"] = pred_to_save
item["final_response"] = response
num_updated += 1
else:
logger.warning(f"Item {i} has unknown type {type(item)}, cannot update with prediction.")
logger.info(f"Updated {num_updated}/{len(dataset)} dataset items with predictions.")
if num_missing_answers > 0:
logger.warning(f"{num_missing_answers} items had no <answer> tag.")
else:
logger.error(
f"Mismatch between dataset size ({len(dataset)}) and number of agent responses ({len(final_responses)}). Cannot reliably update dataset."
)
for item in dataset:
if hasattr(item, "update_output"):
item.update_output("pred", "RESPONSE_COUNT_MISMATCH")
elif isinstance(item, dict):
item["pred"] = "RESPONSE_COUNT_MISMATCH"
if do_eval:
logger.info("📊 Evaluating results using BasicPipeline.evaluate...")
try:
dataset = self.evaluate(dataset, do_eval=True, pred_process_fun=pred_process_fun)
logger.info("✅ Evaluation completed via base class method.")
except Exception as e:
logger.error(f"Error during BasicPipeline.evaluate: {e}", exc_info=True)
logger.warning("Evaluation may be incomplete.")
else:
logger.info("Skipping evaluation step as do_eval=False.")
logger.info("✅ ReSearch pipeline run finished.")
return dataset
class SGLRemoteGenerator(BaseGenerator):
"""Class for decoder-only generator, based on SGLang remote service."""
def __init__(self, config: Config):
"""Initializes the SGLRemoteGenerator."""
super().__init__(config)
logger.info("🔧 Initializing SGLRemoteGenerator...")
sgl_url = getattr(config, "sgl_remote_url", "localhost:8002")
self.sgl_remote_url = f"http://{sgl_url}/generate"
self.health_check_url = f"http://{sgl_url}/health"
logger.info(f"🔗 Remote Generator URL: {self.sgl_remote_url}")
self.model_path = getattr(config, "generator_model_path", None)
if not self.model_path:
logger.error("generator_model_path not found in config!")
raise ValueError("generator_model_path is required for SGLRemoteGenerator")
try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
logger.info("✅ Tokenizer loaded for generator.")
except Exception as e:
logger.error(f"Failed to load tokenizer for generator from {self.model_path}: {e}", exc_info=True)
raise
self.generation_params = getattr(config, "generation_params", {})
self.config = config
self._check_health()
def _check_health(self):
"""Checks the health of the remote generator service."""
try:
test_response = requests.get(self.health_check_url, timeout=5)
test_response.raise_for_status()
logger.info("✅ Remote generator service is available")
except requests.exceptions.RequestException as e:
logger.error(f"Could not connect or verify remote generator service at {self.health_check_url}: {str(e)}")
logger.warning("Please ensure the SGLang service is running and accessible.")
@retry(max_retries=5, sleep=2)
def generate(
self,
input_list: list[str] | str,
return_raw_output: bool = False,
return_scores: bool = False,
**params,
) -> list[str] | tuple[list[str], list[list[float]]] | list[dict]:
"""Generates text using the remote SGLang service."""
if isinstance(input_list, str):
input_list = [input_list]
if not isinstance(input_list, list) or not all(isinstance(item, str) for item in input_list):
raise ValueError("Input must be a string or a list of strings.")
batch_size = len(input_list)
data_to_remote = {"text": input_list}
effective_params = deepcopy(self.generation_params)
effective_params.update(params)
curr_sampling_params = {}
if effective_params.get("do_sample", True) is False:
curr_sampling_params["temperature"] = 0.0
else:
curr_sampling_params["temperature"] = effective_params.get(
"temperature", getattr(self.config, "temperature", 0.7)
)
default_max_tokens = getattr(self.config, "generator_max_output_len", 1024)
curr_sampling_params["max_new_tokens"] = effective_params.get("max_new_tokens", default_max_tokens)
stop_sequences = effective_params.get("stop", [])
if isinstance(stop_sequences, str):
stop_sequences = [stop_sequences]
if stop_sequences:
curr_sampling_params["stop"] = stop_sequences
keys_to_remove = ["do_sample", "temperature", "max_new_tokens", "stop"]
for key in keys_to_remove:
effective_params.pop(key, None)
if "top_p" in effective_params:
curr_sampling_params["top_p"] = effective_params["top_p"]
if "top_k" in effective_params:
curr_sampling_params["top_k"] = effective_params["top_k"]
data_to_remote["sampling_params"] = curr_sampling_params
if return_scores:
data_to_remote["return_logprob"] = True
data_to_remote["top_logprobs_num"] = getattr(self.config, "top_logprobs_num", 2)
try:
response = requests.post(
self.sgl_remote_url, json=data_to_remote, timeout=120, headers={"Content-Type": "application/json"}
)
response.raise_for_status()
response_list = response.json()
if return_raw_output:
return response_list
generated_text = []
for item in response_list:
text = item.get("text", "")
finish_reason = item.get("meta_info", {}).get("finish_reason", {})
matched_stop = finish_reason.get("matched")
if matched_stop and curr_sampling_params.get("stop") and matched_stop in curr_sampling_params["stop"]:
text += matched_stop
generated_text.append(text)
if return_scores:
scores = []
for resp_item in response_list:
logprobs_list = resp_item.get("meta_info", {}).get("output_token_logprobs", [])
token_scores = [
np.exp(logprob[0]) if (logprob and len(logprob) > 0) else 0.0 for logprob in logprobs_list
]
scores.append(token_scores)
return generated_text, scores
else:
return generated_text
except requests.exceptions.Timeout:
logger.error("Generation request timed out after 120 seconds.")
raise
except requests.exceptions.ConnectionError:
logger.error(f"Could not connect to remote generator service at {self.sgl_remote_url}.")
raise
except requests.exceptions.RequestException as e:
logger.error(f"Network error during generation: {str(e)}", exc_info=True)
raise
except json.JSONDecodeError:
response_text = "Unknown (error occurred before response object assignment)"
if "response" in locals() and hasattr(response, "text"):
response_text = response.text[:500]
logger.error(
f"Failed to decode JSON response from {self.sgl_remote_url}. Response text: {response_text}...",
exc_info=True,
)
raise
except Exception as e:
logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True)
raise
def load_dataset_items(config: Config, split: str) -> list[dict | object]:
"""Loads dataset items using flashrag's get_dataset."""
logger.info(f"📚 Loading dataset: {config.dataset_name}, Split: {split}")
try:
all_splits = get_dataset(config)
if split not in all_splits:
logger.error(
f"Split '{split}' not found in dataset '{config.dataset_name}'. Available splits: {list(all_splits.keys())}"
)
return []
dataset_items = all_splits[split]
logger.info(f"Successfully loaded {len(dataset_items)} items for split '{split}'.")
return dataset_items
except FileNotFoundError:
logger.error(
f"Dataset files not found for '{config.dataset_name}' in '{config.data_dir}'. Check config and paths."
)
return []
except Exception as e:
logger.error(f"Error loading dataset using get_dataset: {e}", exc_info=True)
return []
def save_results(args: argparse.Namespace, config: Config, result_dataset, run_duration: float):
"""Saves summary and debug information."""
logger.info("💾 Saving results...")
summary_file = os.path.join(args.save_dir, f"{args.save_note}_summary.txt")
debug_file = os.path.join(args.save_dir, f"{args.save_note}_debug.json")
num_items = len(result_dataset)
logger.info(f"Saving summary results to {summary_file}...")
try:
with open(summary_file, "w", encoding="utf-8") as f:
f.write("EVALUATION SUMMARY\n")
f.write("=================\n\n")
f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"Run Duration: {run_duration:.2f} seconds\n")
f.write(f"Dataset: {config.dataset_name} ({args.split} split)\n")
f.write(f"Model: {config.generator_model_path}\n")
f.write(f"Retriever: {config.remote_retriever_url}\n")
f.write(f"Agent Max Generations: {getattr(config, 'agent_max_generations', 'N/A')}\n")
f.write(f"Generator Max Output Len: {getattr(config, 'generator_max_output_len', 'N/A')}\n\n")
f.write(f"Total items processed: {num_items}\n")
f.write("\nNote: Verification was skipped in this run.\n")
f.write("Note: Overall metrics (like EM, F1) are usually printed to console by evaluate method.\n")
logger.info(f"✅ Summary saved to {summary_file}")
except Exception as e:
logger.error(f"Error saving summary file '{summary_file}': {e}", exc_info=True)
logger.info(f"Saving debug information (predictions & responses) to {debug_file}...")
try:
debug_data = []
for i, item in enumerate(result_dataset):
item_data: dict[str, object] = {}
def get_item_value(data_item, key_or_attr: str) -> str | int | float | list | bool | None:
if isinstance(data_item, dict):
return data_item.get(key_or_attr)
elif hasattr(data_item, key_or_attr):
return getattr(data_item, key_or_attr)
return None
item_data["item_index"] = i
item_data["question"] = get_item_value(item, "question")
item_data["prediction"] = get_item_value(item, "pred")
item_data["final_response"] = get_item_value(item, "final_response")
gt_answer_val = None
try:
gt_answer_val = get_item_value(item, "answer")
if gt_answer_val is None:
answers_list = get_item_value(item, "answers")
if isinstance(answers_list, list) and answers_list:
raw_ans = answers_list[0]
if isinstance(raw_ans, (str, int, float, bool)):
gt_answer_val = raw_ans
else:
gt_answer_val = str(raw_ans)
elif not isinstance(gt_answer_val, (str, int, float, bool)):
gt_answer_val = str(gt_answer_val)
except Exception as e:
logger.warning(f"Could not safely get ground truth for item {i}: {e}")
gt_answer_val = "ERROR_GETTING_ANSWER"
item_data["ground_truth"] = gt_answer_val
eval_score_val = None
try:
eval_score_val = get_item_value(item, "score")
if not isinstance(eval_score_val, (str, int, float, bool, type(None))):
eval_score_val = str(eval_score_val)
except Exception as e:
logger.warning(f"Could not safely get score for item {i}: {e}")
eval_score_val = "ERROR_GETTING_SCORE"
item_data["eval_score"] = eval_score_val
debug_data.append(item_data)
with open(debug_file, "w", encoding="utf-8") as f:
json.dump(debug_data, f, indent=2, ensure_ascii=False)
logger.info(f"✅ Debug information saved to {debug_file}")
except Exception as e:
logger.error(f"Error saving debug file '{debug_file}': {e}", exc_info=True)
def research(args: argparse.Namespace, config: Config):
"""Main function to run the research evaluation pipeline."""
logger.info("🚀 Starting research pipeline execution...")
start_time = time.time()
test_data = load_dataset_items(config, args.split)
if not test_data:
logger.error("Failed to load test data. Exiting.")
return
try:
logger.info("🏗️ Building ReSearchPipeline...")
pipeline = ReSearchPipeline(config)
logger.info("✅ Pipeline built successfully.")
except Exception as e:
logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
return
agent_max_generations = getattr(config, "agent_max_generations", 32)
generator_max_output_len = getattr(config, "generator_max_output_len", 24576)
try:
logger.info("🏃 Starting pipeline run...")
result_dataset = pipeline.run(test_data, do_eval=True)
logger.info("✅ Pipeline run completed.")
except Exception as e:
logger.error(f"Error during pipeline run: {e}", exc_info=True)
result_dataset = test_data
logger.warning("Pipeline run failed, attempting to save inputs/partial results.")
run_duration = time.time() - start_time
logger.info(f"Total run duration: {run_duration:.2f} seconds.")
save_results(args, config, result_dataset, run_duration)
logger.info("🏁 Research pipeline execution finished.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Running ReSearch Evaluation Pipeline")
parser.add_argument(
"--config_path", type=str, default="./eval_config.yaml", help="Path to the main FlashRAG config file."
)
parser.add_argument(
"--dataset_name",
type=str,
default="bamboogle",
help="Name of the dataset (must match config or data_dir structure).",
)
parser.add_argument(
"--split", type=str, default="test", help="Dataset split to evaluate (e.g., test, validation)."
)
parser.add_argument("--save_dir", type=str, default="./output_logs", help="Directory to save logs and results.")
parser.add_argument("--save_note", type=str, default="research_run", help="A note to prepend to saved filenames.")
parser.add_argument("--data_dir", type=str, help="Override data directory specified in config.")
parser.add_argument(
"--sgl_remote_url", type=str, help="Override SGLang remote generator URL (e.g., localhost:8002)."
)
parser.add_argument(
"--remote_retriever_url", type=str, help="Override remote retriever URL (e.g., localhost:8001)."
)
parser.add_argument("--generator_model_path", type=str, help="Override generator model path specified in config.")
parser.add_argument("--retriever_topk", type=int, help="Override retriever top K.")
parser.add_argument("--generator_max_output_len", type=int, help="Override generator max output length.")
parser.add_argument("--agent_max_generations", type=int, help="Override agent max interaction turns.")
args = parser.parse_args()
logger.info(f"Starting evaluation script with arguments: {args}")
try:
os.makedirs(args.save_dir, exist_ok=True)
logger.info(f"💾 Logs and results will be saved to: {args.save_dir}")
except OSError as e:
logger.error(f"Could not create save directory '{args.save_dir}': {e}", exc_info=True)
exit(1)
config_overrides = {
k: v
for k, v in vars(args).items()
if v is not None
and k
not in [
"config_path",
"dataset_name",
"split",
"save_dir",
"save_note",
]
}
logger.info(f"🔧 Loading configuration from: {args.config_path}")
try:
config = Config(args.config_path, config_dict=config_overrides)
config.dataset_name = args.dataset_name
if args.data_dir:
config.data_dir = args.data_dir
logger.info(f"Effective data_dir: {getattr(config, 'data_dir', 'N/A')}")
logger.info(f"Effective generator_model_path: {getattr(config, 'generator_model_path', 'N/A')}")
logger.info(f"Effective sgl_remote_url: {getattr(config, 'sgl_remote_url', 'N/A')}")
logger.info(f"Effective remote_retriever_url: {getattr(config, 'remote_retriever_url', 'N/A')}")
logger.info("✅ Config loaded and potentially overridden by CLI arguments.")
config["dataset_path"] = os.path.join(config.data_dir, config.dataset_name)
except FileNotFoundError:
logger.error(f"Config file not found at '{args.config_path}'. Please check the path.")
exit(1)
except Exception as e:
logger.error(f"Error loading or processing configuration: {e}", exc_info=True)
exit(1)
research(args, config)

@@ -1,68 +0,0 @@
import argparse
import os
import zipfile
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from config import DATA_DIR
def parse_args() -> argparse.Namespace:
"""Parse command line arguments.
Returns:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(description="Download FlashRAG datasets from HuggingFace Hub")
parser.add_argument(
"--repo-id",
type=str,
default="RUC-NLPIR/FlashRAG_datasets",
help="HuggingFace repository IDs",
)
parser.add_argument(
"--local-dir",
type=str,
default=DATA_DIR / "flashrag_datasets",
help="Local directory to save model",
)
return parser.parse_args()
def main():
"""Main function to download model."""
args = parse_args()
load_dotenv(override=True)
# Configuration
HF_TOKEN = os.getenv("HF_TOKEN")
ALLOW_PATTERNS = [
"*retrieval-corpus*",
"*bamboogle*",
"*nq*",
]
# Download the model
snapshot_download(
token=HF_TOKEN,
repo_id=args.repo_id,
local_dir=args.local_dir,
repo_type="dataset",
# ignore_patterns=IGNORE_PATTERNS,
allow_patterns=ALLOW_PATTERNS,
)
# unzip data/flashrag_datasets/retrieval-corpus/wiki18_100w.zip
print("Unzipping wiki18_100w.zip. Might take a while...")
zip_file_path = os.path.join(args.local_dir, "retrieval-corpus", "wiki18_100w.zip")
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(args.local_dir)
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
if __name__ == "__main__":
main()

@@ -1,43 +0,0 @@
"""Download and extract FlashRAG index."""
import os
import zipfile
import requests
from tqdm import tqdm
from config import DATA_DIR
# Constants
URL = "https://www.modelscope.cn/datasets/hhjinjiajie/FlashRAG_Dataset/resolve/master/retrieval_corpus/wiki18_100w_e5_index.zip"
ZIP_NAME = "wiki18_100w_e5_index.zip"
zip_path = DATA_DIR / ZIP_NAME
# Download with progress bar
print("📥 Downloading index...")
response = requests.get(URL, stream=True)
total_size = int(response.headers.get("content-length", 0))
with (
open(zip_path, "wb") as f,
tqdm(
desc=ZIP_NAME,
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar,
):
for data in response.iter_content(chunk_size=1024):
size = f.write(data)
bar.update(size)
# Extract
print("📦 Extracting index...")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(DATA_DIR)
# Clean up zip
os.remove(zip_path)
print("✅ Download and extraction completed successfully!")
print(f"Index file is at: {DATA_DIR}/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index")

@@ -1,54 +0,0 @@
import argparse
import os
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID
def parse_args() -> argparse.Namespace:
"""Parse command line arguments.
Returns:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
parser.add_argument(
"--repo-id",
type=str,
default=GENERATOR_MODEL_REPO_ID,
help="HuggingFace repository ID",
)
parser.add_argument(
"--local-dir",
type=str,
default=GENERATOR_MODEL_DIR,
help="Local directory to save model",
)
return parser.parse_args()
def main():
"""Main function to download model."""
args = parse_args()
load_dotenv(override=True)
# Configuration
HF_TOKEN = os.getenv("HF_TOKEN")
print("Downloading model to", args.local_dir)
# Download the model
snapshot_download(
token=HF_TOKEN,
repo_id=args.repo_id,
local_dir=args.local_dir,
repo_type="model",
)
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
if __name__ == "__main__":
main()

@@ -1,54 +0,0 @@
import argparse
import os
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from config import RETRIEVER_MODEL_DIR, RETRIEVER_MODEL_REPO_ID
def parse_args() -> argparse.Namespace:
"""Parse command line arguments.
Returns:
argparse.Namespace: Parsed arguments
"""
parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
parser.add_argument(
"--repo-id",
type=str,
default=RETRIEVER_MODEL_REPO_ID,
help="HuggingFace repository ID",
)
parser.add_argument(
"--local-dir",
type=str,
default=RETRIEVER_MODEL_DIR,
help="Local directory to save model",
)
return parser.parse_args()
def main():
"""Main function to download model."""
args = parse_args()
load_dotenv(override=True)
# Configuration
HF_TOKEN = os.getenv("HF_TOKEN")
print("Downloading model to", args.local_dir)
# Download the model
snapshot_download(
token=HF_TOKEN,
repo_id=args.repo_id,
local_dir=args.local_dir,
repo_type="model",
)
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
if __name__ == "__main__":
main()

@@ -1,9 +0,0 @@
# ------------------------------------------------Environment Settings------------------------------------------------#
gpu_id: "0"
# -------------------------------------------------Retrieval Settings------------------------------------------------#
# If the name is set, the model path will be found in the global paths
retrieval_method: "e5" # name or path of the retrieval model.
index_path: "/mnt/nas/thinhlpg/data/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index" # path to the indexed file
faiss_gpu: False # whether to use gpu to hold the index
corpus_path: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/wiki18_100w.jsonl" # path to the corpus in '.jsonl' format that stores the documents
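
The retriever service below loads this file through flashrag's Config and builds the retriever with get_retriever; a minimal sketch of that load step, with the config path taken from the Makefile's serve-retriever target:

from flashrag.config import Config
from flashrag.utils import get_retriever

config = Config("scripts/serving/retriever_config.yaml")
retriever = get_retriever(config)
docs = retriever.search("test query", 5, False)  # (query, top_n, return_score), as used in the service below
print(docs[:1])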

@@ -1,127 +0,0 @@
import subprocess
import sys
from pathlib import Path
# Add project root to sys.path to allow importing config
# Assuming the script is at DeepSearch/scripts/serving/serve_generator.py
# The project root (DeepSearch) is parents[2]
PROJ_ROOT = Path(__file__).resolve().parents[2]
if str(PROJ_ROOT) not in sys.path:
sys.path.append(str(PROJ_ROOT))
# Import after adjusting sys.path
try:
from config import (
GENERATOR_MODEL_REPO_ID,
GENERATOR_SERVER_PORT,
MODEL_CONFIG,
logger,
)
except ImportError as e:
# Use print here as logger might not be available if import failed
print(
f"Error importing config: {e}. Make sure config.py is in the project root ({PROJ_ROOT}) and added to sys.path."
)
sys.exit(1)
def launch_sglang_server(
model_id: str,
port: int,
context_length: int,
host: str = "0.0.0.0",
dtype: str = "bfloat16",
) -> None:
"""Launches the SGLang server using specified configurations.
Args:
model_id: The Hugging Face repository ID of the model.
port: The port number for the server.
context_length: The maximum context length for the model.
host: The host address for the server.
dtype: The data type for the model (e.g., 'bfloat16', 'float16').
"""
command = [
sys.executable, # Use the current Python interpreter
"-m",
"sglang.launch_server",
"--model-path",
model_id,
"--context-length",
str(context_length),
"--enable-metrics",
"--dtype",
dtype,
"--host",
host,
"--port",
str(port),
"--mem-fraction-static",
"0.5",
"--trust-remote-code",
# Recommended by SGLang for stability sometimes
"--disable-overlap",
# Can sometimes cause issues
"--disable-radix-cache",
]
# Log the command clearly
command_str = " ".join(command)
logger.info(f"🚀 Launching SGLang server with command: {command_str}")
process = None # Initialize process to None
try:
# Use Popen to start the server process
# It runs in the foreground relative to this script,
# but allows us to catch KeyboardInterrupt cleanly.
process = subprocess.Popen(command)
# Wait for the process to complete (e.g., user interruption)
process.wait()
# Check return code after waiting
if process.returncode != 0:
logger.error(f"💥 SGLang server process exited with error code: {process.returncode}")
sys.exit(process.returncode)
else:
logger.info("✅ SGLang server process finished gracefully.")
except FileNotFoundError:
logger.error(f"💥 Error: Python executable or sglang module not found.")
logger.error(f"Ensure '{sys.executable}' is correct and sglang is installed.")
sys.exit(1)
except KeyboardInterrupt:
logger.info("🛑 SGLang server launch interrupted by user. Stopping server...")
# Attempt to terminate the process gracefully
if process and process.poll() is None: # Check if process exists and is running
process.terminate()
try:
process.wait(timeout=5) # Wait a bit for termination
logger.info("✅ Server terminated gracefully.")
except subprocess.TimeoutExpired:
logger.warning("⚠️ Server did not terminate gracefully, forcing kill.")
process.kill()
sys.exit(0) # Exit cleanly after interrupt
except Exception as e:
# Catch any other unexpected exceptions during launch or waiting
logger.error(f"🚨 An unexpected error occurred: {e}")
# Ensure process is cleaned up if it exists
if process and process.poll() is None:
process.kill()
sys.exit(1)
if __name__ == "__main__":
# Get context length from config, default to 8192
context_len = MODEL_CONFIG.get("max_seq_length", 8192)
logger.info("----------------------------------------------------")
logger.info("✨ Starting SGLang Generator Server ✨")
logger.info(f" Model ID: {GENERATOR_MODEL_REPO_ID}")
logger.info(f" Port: {GENERATOR_SERVER_PORT}")
logger.info(f" Context Length: {context_len}")
logger.info("----------------------------------------------------")
launch_sglang_server(
model_id=GENERATOR_MODEL_REPO_ID,
port=GENERATOR_SERVER_PORT,
context_length=context_len,
)
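
Once launched, the generator exposes the /health endpoint that SGLRemoteGenerator polls before use; a minimal readiness check, assuming the default port 8002 from the evaluation config:

import requests

GENERATOR_URL = "http://localhost:8002"  # assumed default, matching sgl_remote_url in eval_config.yaml

def generator_is_healthy(timeout_s: float = 5.0) -> bool:
    """Return True if the SGLang /health endpoint responds with a 2xx status."""
    try:
        return requests.get(f"{GENERATOR_URL}/health", timeout=timeout_s).ok
    except requests.exceptions.RequestException:
        return False

print("generator healthy:", generator_is_healthy())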

@@ -1,113 +0,0 @@
import argparse
import asyncio
from collections import deque
from typing import List, Tuple, Union
from fastapi import FastAPI, HTTPException
from flashrag.config import Config
from flashrag.utils import get_retriever
from pydantic import BaseModel
from config import RETRIEVER_SERVER_PORT
app = FastAPI()
retriever_list = []
available_retrievers = deque()
retriever_semaphore = None
def init_retriever(args):
global retriever_semaphore
config = Config(args.config)
for i in range(args.num_retriever):
print(f"Initializing retriever {i + 1}/{args.num_retriever}")
retriever = get_retriever(config)
retriever_list.append(retriever)
available_retrievers.append(i)
# create a semaphore to limit the number of retrievers that can be used concurrently
retriever_semaphore = asyncio.Semaphore(args.num_retriever)
@app.get("/health")
async def health_check():
return {"status": "healthy", "retrievers": {"total": len(retriever_list), "available": len(available_retrievers)}}
class QueryRequest(BaseModel):
query: str
top_n: int = 10
return_score: bool = False
class BatchQueryRequest(BaseModel):
query: List[str]
top_n: int = 10
return_score: bool = False
class Document(BaseModel):
id: str
contents: str
@app.post("/search", response_model=Union[Tuple[List[Document], List[float]], List[Document]])
async def search(request: QueryRequest):
query = request.query
top_n = request.top_n
return_score = request.return_score
if not query or not query.strip():
print(f"Query content cannot be empty: {query}")
raise HTTPException(status_code=400, detail="Query content cannot be empty")
async with retriever_semaphore:
retriever_idx = available_retrievers.popleft()
try:
if return_score:
results, scores = retriever_list[retriever_idx].search(query, top_n, return_score)
return [Document(id=result["id"], contents=result["contents"]) for result in results], scores
else:
results = retriever_list[retriever_idx].search(query, top_n, return_score)
return [Document(id=result["id"], contents=result["contents"]) for result in results]
finally:
available_retrievers.append(retriever_idx)
@app.post("/batch_search", response_model=Union[List[List[Document]], Tuple[List[List[Document]], List[List[float]]]])
async def batch_search(request: BatchQueryRequest):
query = request.query
top_n = request.top_n
return_score = request.return_score
async with retriever_semaphore:
retriever_idx = available_retrievers.popleft()
try:
if return_score:
results, scores = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
return [
[Document(id=result["id"], contents=result["contents"]) for result in results[i]]
for i in range(len(results))
], scores
else:
results = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
return [
[Document(id=result["id"], contents=result["contents"]) for result in results[i]]
for i in range(len(results))
]
finally:
available_retrievers.append(retriever_idx)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./retriever_config.yaml")
parser.add_argument("--num_retriever", type=int, default=1)
parser.add_argument("--port", type=int, default=RETRIEVER_SERVER_PORT)
args = parser.parse_args()
init_retriever(args)
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=args.port)
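
A minimal client for this service, mirroring the request shape RemoteRetriever sends in run_eval.py above; the port is assumed to be the 8001 default from the evaluation config:

import requests

RETRIEVER_URL = "http://localhost:8001"  # assumed default, matching remote_retriever_url in eval_config.yaml

# Single query against /search.
docs = requests.post(
    f"{RETRIEVER_URL}/search",
    json={"query": "Who wrote Hamlet?", "top_n": 5, "return_score": False},
    timeout=30,
).json()
for doc in docs:
    print(doc["id"], doc["contents"][:80])

# /batch_search takes a list of queries under the same "query" key.
batch = requests.post(
    f"{RETRIEVER_URL}/batch_search",
    json={"query": ["Who wrote Hamlet?", "Capital of France?"], "top_n": 3, "return_score": False},
    timeout=60,
).json()
print(len(batch), "result lists returned")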

@@ -1,135 +0,0 @@
import json
import math # Import math for ceiling division
import sys
import traceback # Import traceback
from pathlib import Path
import pandas as pd
# Add project root to Python path if needed (adjust relative path as necessary)
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))
from src.embeddings import CustomHuggingFaceEmbeddings
# Import FAISS after potentially adding to sys.path
try:
from langchain_community.vectorstores import FAISS
except ImportError:
print("Error: langchain_community or FAISS not installed. Please install with 'pip install langchain faiss-cpu'")
sys.exit(1)
def build_faiss_index_from_csv(csv_path: str, index_save_path: str, batch_size: int = 128) -> None:
"""Builds a FAISS index from a CSV containing paragraph content and metadata.
Reads a CSV file, generates embeddings for the 'content' column in batches,
and saves the FAISS index files (index.faiss, index.pkl) locally.
Args:
csv_path: Path to the input CSV file (e.g., data/processed/paragraphs.csv).
index_save_path: Path to the directory where the index files should be saved.
batch_size: Number of texts to process in each embedding batch.
"""
print(f"Loading paragraphs from {csv_path}")
try:
df = pd.read_csv(csv_path)
except FileNotFoundError:
print(f"Error: CSV file not found at {csv_path}. Please run the extraction script first.")
return
except Exception as e:
print(f"Error reading CSV file: {e}")
return
if "content" not in df.columns or "metadata" not in df.columns:
print("Error: CSV file must contain 'content' and 'metadata' columns.")
return
if df.empty:
print("Warning: Input CSV file is empty. No index will be built.")
return
# Prepare documents for FAISS
texts = df["content"].astype(str).tolist()
metadatas = []
try:
metadatas = [json.loads(m) for m in df["metadata"].tolist()]
print(f"Prepared {len(texts)} texts and {len(metadatas)} metadatas.")
except json.JSONDecodeError as e:
print(f"Error parsing metadata JSON: {e}. Check the format in {csv_path}")
traceback.print_exc() # Print traceback for JSON errors
return
except Exception as e:
print(f"Error processing metadata: {e}")
traceback.print_exc() # Print traceback for other metadata errors
return
if not texts or not metadatas or len(texts) != len(metadatas):
print(f"Error: Mismatch or empty texts/metadatas. Texts: {len(texts)}, Metadatas: {len(metadatas)}")
return
print("Initializing embeddings model...")
try:
embeddings = CustomHuggingFaceEmbeddings()
except Exception as e:
print(f"Error initializing embeddings model: {e}")
traceback.print_exc()
return
print("Embeddings model initialized successfully.")
vectorstore = None
num_batches = math.ceil(len(texts) / batch_size)
print(f"Processing {len(texts)} texts in {num_batches} batches of size {batch_size}...")
for i in range(num_batches):
start_idx = i * batch_size
end_idx = min((i + 1) * batch_size, len(texts))
batch_texts = texts[start_idx:end_idx]
batch_metadatas = metadatas[start_idx:end_idx]
print(f" Processing batch {i + 1}/{num_batches} (indices {start_idx}-{end_idx - 1})...")
try:
if i == 0:
# Initialize the vector store with the first batch
print(f" Initializing FAISS index with first batch...")
vectorstore = FAISS.from_texts(texts=batch_texts, embedding=embeddings, metadatas=batch_metadatas)
print(" FAISS index initialized.")
else:
# Add subsequent batches to the existing store
if vectorstore is None:
print("Error: vectorstore is None after first batch, cannot add more texts.")
return # Should not happen if first batch succeeded
print(f" Adding batch {i + 1} to FAISS index...")
vectorstore.add_texts(texts=batch_texts, metadatas=batch_metadatas)
print(f" Batch {i + 1} added.")
except Exception as e:
print(f"Error processing batch {i + 1} (indices {start_idx}-{end_idx - 1}): {e}")
traceback.print_exc()
print("Stopping index creation due to error in batch processing.")
return # Exit if any batch fails
if vectorstore is None:
print("Error: Failed to create or add any data to the vectorstore.")
return
# Save the completed index
try:
print(f"Attempting to save final FAISS index files to directory: {index_save_path}")
# Ensure the target directory exists before saving
Path(index_save_path).mkdir(parents=True, exist_ok=True)
vectorstore.save_local(index_save_path)
print(f"Successfully saved final FAISS index files (index.faiss, index.pkl) to: {index_save_path}")
except Exception as e:
print(f"Error during final vectorstore.save_local to {index_save_path}: {e}")
traceback.print_exc()
if __name__ == "__main__":
# Define paths relative to this script or use absolute paths
PROCESSED_DIR = Path("data/processed")
INPUT_CSV = str(PROCESSED_DIR / "paragraphs.csv")
# FAISS save_local will save index.faiss and index.pkl in this directory
INDEX_SAVE_DIR = str(PROCESSED_DIR) # Save directly to processed dir
build_faiss_index_from_csv(INPUT_CSV, INDEX_SAVE_DIR, batch_size=128)
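
Once the index files are written, they can be loaded back with the same embeddings class for ad-hoc queries; a minimal sketch, assuming the index was saved to data/processed/ as above:

from langchain_community.vectorstores import FAISS

from src.embeddings import CustomHuggingFaceEmbeddings

embeddings = CustomHuggingFaceEmbeddings()
vectorstore = FAISS.load_local(
    "data/processed",
    embeddings,
    allow_dangerous_deserialization=True,  # required by recent langchain_community releases
)
for doc in vectorstore.similarity_search("Who founded the city of Boston?", k=3):
    print(doc.metadata.get("source_question_ids"), doc.page_content[:80])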

@@ -1,30 +0,0 @@
#!/bin/bash
# This script is taken from https://github.com/StonyBrookNLP/musique with slight modifications
set -e
set -x
# If gdown doesn't work, you can download files from mentioned URLs manually
# and put them at appropriate locations.
pip install gdown
ZIP_NAME="musique_v1.0.zip"
# URL: https://drive.google.com/file/d/1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h/view?usp=sharing
gdown --id 1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h --output $ZIP_NAME
TARGET_DIR="./data/raw"
mkdir -p $TARGET_DIR
unzip -o $(basename $ZIP_NAME) -d $TARGET_DIR # Extract directly into target
# Move contents from the extracted 'data' folder up one level
mv $TARGET_DIR/data/* $TARGET_DIR/
# Clean up the empty directory and the zip
rm -rf $TARGET_DIR/data
rm $ZIP_NAME
# TODO: prevent these from zipping in.
rm -rf __MACOSX
# Clean up potential extracted .DS_Store
rm -f $TARGET_DIR/.DS_Store

@@ -1,101 +0,0 @@
import json
import sys
from collections import defaultdict # Use defaultdict for cleaner accumulation
from pathlib import Path
import pandas as pd
# Add project root to Python path if needed (adjust relative path as necessary)
# project_root = Path(__file__).resolve().parent.parent
# sys.path.append(str(project_root))
# from config import logger # Assuming you have a logger setup
def extract_unique_paragraphs(input_paths: list[str], output_csv_path: str) -> None:
"""Extracts unique paragraphs from specified JSONL files.
Reads Musique JSONL files (train, dev, test), finds unique paragraphs
(regardless of is_supporting flag), combines title and text,
tracks source question IDs, and saves to CSV.
Args:
input_paths: A list of paths to the input JSONL files.
output_csv_path: Path to save the output CSV file.
"""
output_dir = Path(output_csv_path).parent
output_dir.mkdir(parents=True, exist_ok=True)
# Use paragraph content as key, value is the set of source question IDs
paragraphs_data = defaultdict(set)
print("Starting paragraph extraction (including non-supporting)...")
for file_path in input_paths:
print(f"Processing file: {file_path}")
try:
with open(file_path, "r", encoding="utf-8") as infile:
for line_num, line in enumerate(infile, 1):
try:
data = json.loads(line)
main_question_id = data.get("id")
if not main_question_id:
print(f"Warning: Missing 'id' in line {line_num} of {file_path}")
continue
for p in data.get("paragraphs", []):
title = p.get("title", "No Title")
text = p.get("paragraph_text", "")
content = f"{title}\n{text}".strip()
if not content:
continue # Skip empty paragraphs
paragraphs_data[content].add(main_question_id)
except json.JSONDecodeError:
print(f"Warning: Skipping invalid JSON in line {line_num} of {file_path}")
except Exception as e:
print(f"Warning: Error processing line {line_num} in {file_path}: {e}")
except FileNotFoundError:
print(f"Error: Input file not found: {file_path}")
except Exception as e:
print(f"Error reading file {file_path}: {e}")
print(f"Found {len(paragraphs_data)} unique paragraphs (supporting and non-supporting).")
# Prepare data for DataFrame
output_list = []
sorted_content = sorted(paragraphs_data.keys())
for chunk_id, content in enumerate(sorted_content, 1):
question_ids = paragraphs_data[content]
metadata = {"source_question_ids": sorted(list(question_ids))}
output_list.append(
{
"chunk_id": chunk_id,
"content": content,
"metadata": json.dumps(metadata), # Store metadata as JSON string
}
)
if not output_list:
print("No paragraphs found to save.")
return
df = pd.DataFrame(output_list)
try:
df.to_csv(output_csv_path, index=False)
print(f"Successfully saved unique paragraphs to {output_csv_path}")
except Exception as e:
print(f"Error saving CSV file: {e}")
if __name__ == "__main__":
RAW_DIR = Path("data/raw")
PROCESSED_DIR = Path("data/processed")
input_files = [
str(RAW_DIR / "musique_ans_v1.0_train.jsonl"),
str(RAW_DIR / "musique_ans_v1.0_dev.jsonl"),
str(RAW_DIR / "musique_ans_v1.0_test.jsonl"),
]
output_csv = str(PROCESSED_DIR / "paragraphs.csv")
extract_unique_paragraphs(input_files, output_csv)
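
A quick sanity check of the CSV this script writes (columns per the code above: chunk_id, content, and a JSON-encoded metadata string):

import json

import pandas as pd

df = pd.read_csv("data/processed/paragraphs.csv")
row = df.iloc[0]
print(row["chunk_id"], json.loads(row["metadata"])["source_question_ids"][:3])
print(row["content"][:120])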

@@ -1,155 +0,0 @@
"""Prepares a deterministic sampled dev set (questions_dev.jsonl) from raw Musique dev data."""
import json
import math
import os
import re
from collections import defaultdict
from pathlib import Path
def transform_musique_dev_data(input_path: str, output_path: str, sample_config: dict) -> None:
"""Transforms Musique dev data with deterministic stratified sampling using uniform selection from sorted lists.
Reads dev data, categorizes by hop type (2, 3, 4), sorts categories by ID,
selects N samples uniformly spaced from each sorted category based on sample_config,
combines, sorts final list by ID, combines answers/aliases, extracts supporting paras,
and writes the transformed data to output_path.
Args:
input_path: Path to the input JSONL file (e.g., data/raw/musique_ans_v1.0_dev.jsonl).
output_path: Path to the output JSONL file (e.g., data/processed/questions_dev.jsonl).
sample_config: Dictionary specifying samples per hop type (e.g., {"2hop": 20, "3hop": 15, "4hop": 15}).
"""
output_dir = Path(output_path).parent
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Reading all data from {input_path} for dev sampling...")
all_data = []
try:
with open(input_path, "r", encoding="utf-8") as infile:
for line_num, line in enumerate(infile, 1):
try:
data = json.loads(line)
if "id" in data:
all_data.append(data)
else:
print(f"Warning: Skipping line {line_num} due to missing 'id' field in {input_path}")
except json.JSONDecodeError:
print(f"Warning: Skipping invalid JSON in line {line_num} of {input_path}")
except FileNotFoundError:
print(f"Error: Input file not found at {input_path}")
return
except Exception as e:
print(f"Error reading file {input_path}: {e}")
return
print(f"Read {len(all_data)} total samples from dev set.")
# Categorize data by hop count (2hop, 3hop, 4hop)
categorized_data = defaultdict(list)
print("Categorizing data by hop type (2, 3, 4)...")
for data in all_data:
q_id = data["id"]
hop_type = None
if q_id.startswith("2hop"):
hop_type = "2hop"
elif q_id.startswith("3hop"):
hop_type = "3hop"
elif q_id.startswith("4hop"):
hop_type = "4hop"
if hop_type:
categorized_data[hop_type].append(data)
# Deterministic sampling using sorting and uniform index selection
final_sample_list = []
total_target = sum(sample_config.values())
print(f"Sampling deterministically via uniform selection from sorted lists to get {total_target} dev samples...")
for hop_type, target_count in sample_config.items():
available_samples = categorized_data.get(hop_type, [])
current_count = len(available_samples)
print(f" {hop_type}: Found {current_count} samples, need {target_count}.")
if current_count == 0:
continue
available_samples.sort(key=lambda x: x["id"])
selected_samples_for_hop = []
if current_count < target_count:
print(f" Warning: Not enough samples for {hop_type}. Taking all {current_count} sorted samples.")
selected_samples_for_hop = available_samples
elif target_count > 0: # Ensure target_count is positive before selecting
print(f" Selecting {target_count} samples uniformly from {current_count}...")
# Evenly spaced indices over the sorted list, truncated to integers;
# the (current_count - 1) / (target_count - 1) scaling includes both endpoints.
indices_to_take = [
int(i * (current_count - 1) / (target_count - 1)) if target_count > 1 else 0
for i in range(target_count)
]
indices_to_take = sorted(list(set(indices_to_take))) # Ensure unique indices
# Simple fallback if uniqueness reduced count below target
while len(indices_to_take) < target_count and len(indices_to_take) < current_count:
next_val = indices_to_take[-1] + 1
if next_val < current_count:
indices_to_take.append(next_val)
else: # Cannot add more unique indices
break
selected_samples_for_hop = [
available_samples[idx] for idx in indices_to_take[:target_count]
] # Select based on unique indices, capped at target
final_sample_list.extend(selected_samples_for_hop)
print(f"Selected {len(final_sample_list)} dev samples in total.")
# Sort the final combined list by ID for consistent output order
print("Sorting the final combined dev sample list by ID...")
final_sample_list.sort(key=lambda x: x["id"])
# Process and write the selected samples
print(f"Processing and writing {len(final_sample_list)} selected dev samples to {output_path}...")
count = 0
try:
with open(output_path, "w", encoding="utf-8") as outfile:
for data in final_sample_list:
try:
supporting_paragraphs = [
p["paragraph_text"] for p in data.get("paragraphs", []) if p.get("is_supporting", False)
]
main_answer = data.get("answer", "")
aliases = data.get("answer_aliases", [])
all_answers = [main_answer] + (aliases if isinstance(aliases, list) else [])
valid_answers = [str(ans).strip() for ans in all_answers if ans and str(ans).strip()]
unique_valid_answers = list(set(valid_answers)) # Deduplicate; note that set() does not preserve the original answer order
combined_answer_str = " OR ".join(unique_valid_answers)
output_data = {
"id": data.get("id"),
"question": data.get("question"),
"answer": combined_answer_str,
"supporting_paragraphs": supporting_paragraphs,
}
outfile.write(json.dumps(output_data) + "\n")
count += 1
except KeyError as e:
print(f"Skipping sample due to missing key {e}: {data.get('id')}")
print(f"Successfully processed and wrote {count} dev samples.")
except Exception as e:
print(f"An unexpected error occurred during writing: {e}")
if __name__ == "__main__":
# Define file paths relative to the project root
# Ensure this script is run from the project root or adjust paths accordingly
RAW_DIR = Path("data/raw")
PROCESSED_DIR = Path("data/processed")
# Define sampling configuration for the dev set
DEV_SAMPLING_CONFIG = {"2hop": 20, "3hop": 15, "4hop": 15} # Total = 50
INPUT_FILE = RAW_DIR / "musique_ans_v1.0_dev.jsonl"
OUTPUT_FILE = PROCESSED_DIR / "questions_dev.jsonl"
transform_musique_dev_data(str(INPUT_FILE), str(OUTPUT_FILE), DEV_SAMPLING_CONFIG)
print(f"\nMusique DEV JSONL transformation and deterministic sampling complete.")

@ -1,172 +0,0 @@
import json
import math # Keep math import
import os
import re # Import re for parsing ID
from collections import defaultdict
from pathlib import Path
# import random # No longer needed
# SEED = 42 # No longer needed
# random.seed(SEED) # No longer needed
def transform_musique_data(input_path: str, output_path: str, sample_config: dict) -> None:
"""Transforms Musique data with deterministic stratified sampling using uniform selection from sorted lists.
Reads data, categorizes by detailed hop type, sorts categories by ID,
selects N samples uniformly spaced from each sorted category,
combines, sorts final list by ID, and writes to output.
Args:
input_path: Path to the input JSONL file.
output_path: Path to the output JSONL file.
sample_config: Dictionary specifying samples per detailed hop type (e.g., {"2hop": 400, "3hop1": 150, ...}).
"""
output_dir = Path(output_path).parent
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Reading all data from {input_path} for sampling...")
all_data = []
try:
with open(input_path, "r", encoding="utf-8") as infile:
for line_num, line in enumerate(infile, 1):
try:
data = json.loads(line)
if "id" in data:
all_data.append(data)
else:
print(f"Warning: Skipping line {line_num} due to missing 'id' field in {input_path}")
except json.JSONDecodeError:
print(f"Warning: Skipping invalid JSON in line {line_num} of {input_path}")
except FileNotFoundError:
print(f"Error: Input file not found at {input_path}")
return
except Exception as e:
print(f"Error reading file {input_path}: {e}")
return
print(f"Read {len(all_data)} total samples with IDs.")
# Detailed Categorization by hop type
categorized_data = defaultdict(list)
print("Categorizing data by detailed hop type (e.g., 3hop1, 4hop2)...")
for data in all_data:
q_id = data["id"]
match = re.match(r"^(2hop|3hop[12]|4hop[123])__", q_id)
if match:
detailed_hop_type = match.group(1)
categorized_data[detailed_hop_type].append(data)
# else: # Optional: log if an ID doesn't match expected pattern
# print(f"Warning: ID {q_id} does not match expected hop pattern.")
# Deterministic sampling using sorting and uniform index selection
final_sample_list = []
total_target = sum(sample_config.values())
print(f"Sampling deterministically via uniform selection from sorted lists to get {total_target} samples...")
# Check if all requested hop types exist in config
for hop_type in sample_config.keys():
if hop_type not in categorized_data:
print(f"Warning: Hop type '{hop_type}' requested in config but not found in data.")
for hop_type, target_count in sample_config.items():
available_samples = categorized_data.get(hop_type, [])
current_count = len(available_samples)
print(f" {hop_type}: Found {current_count} samples, need {target_count}.")
if current_count == 0:
continue
# Sort the list for this category by ID
available_samples.sort(key=lambda x: x["id"])
selected_samples_for_hop = []
if current_count < target_count:
print(f" Warning: Not enough samples for {hop_type}. Taking all {current_count} sorted samples.")
selected_samples_for_hop = available_samples
else:
# Select target_count indices spread uniformly across the available samples
print(f" Selecting {target_count} samples uniformly from {current_count}...")
# Evenly spaced indices across the sorted samples: integer truncation of i * current_count / target_count
indices_to_take = [int(i * current_count / target_count) for i in range(target_count)]
# Ensure uniqueness in case of rounding issues with small numbers (though unlikely here)
indices_to_take = sorted(list(set(indices_to_take)))
# Adjust if rounding resulted in fewer than target_count unique indices
while len(indices_to_take) < target_count:
# This is a fallback, shouldn't happen if current_count >= target_count
# Add indices from the end if needed, avoiding duplicates
next_idx = indices_to_take[-1] + 1
if next_idx < current_count and next_idx not in indices_to_take:
indices_to_take.append(next_idx)
else: # Should not be reachable if logic is sound
break
# Select samples at the calculated indices
selected_samples_for_hop = [
available_samples[idx] for idx in indices_to_take[:target_count]
] # Ensure we take exactly target_count
final_sample_list.extend(selected_samples_for_hop)
print(f"Selected {len(final_sample_list)} samples in total.")
# Sort the final combined list by ID for consistent output order
print("Sorting the final combined sample list by ID...")
final_sample_list.sort(key=lambda x: x["id"])
# Process and write the selected samples
print(f"Processing and writing {len(final_sample_list)} selected samples to {output_path}...")
count = 0
try:
with open(output_path, "w", encoding="utf-8") as outfile:
for data in final_sample_list:
try:
supporting_paragraphs = [
p["paragraph_text"] for p in data.get("paragraphs", []) if p.get("is_supporting", False)
]
main_answer = data.get("answer", "")
aliases = data.get("answer_aliases", [])
all_answers = [main_answer] + (aliases if isinstance(aliases, list) else [])
valid_answers = [str(ans).strip() for ans in all_answers if ans and str(ans).strip()]
unique_valid_answers = list(set(valid_answers))
combined_answer_str = " OR ".join(unique_valid_answers)
output_data = {
"id": data.get("id"),
"question": data.get("question"),
"answer": combined_answer_str,
"supporting_paragraphs": supporting_paragraphs,
}
outfile.write(json.dumps(output_data) + "\n")
count += 1
except KeyError as e:
print(f"Skipping sample due to missing key {e}: {data.get('id')}")
print(f"Successfully processed and wrote {count} samples.")
except Exception as e:
print(f"An unexpected error occurred during writing: {e}")
if __name__ == "__main__":
# Define file paths
RAW_DIR = Path("data/raw")
PROCESSED_DIR = Path("data/processed")
# Define detailed sampling configuration
SAMPLING_CONFIG = {
"2hop": 400,
"3hop1": 150,
"3hop2": 150,
"4hop1": 100,
"4hop2": 100,
"4hop3": 100,
} # Total = 1000
transform_musique_data(
str(RAW_DIR / "musique_ans_v1.0_train.jsonl"), str(PROCESSED_DIR / "questions.jsonl"), SAMPLING_CONFIG
)
print(
"\nMusique JSONL transformation and detailed deterministic sampling (uniform selection from sorted) complete."
)
# Note: Dev/Test files are not processed by default with this sampling logic.
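For comparison, this training-set script scales indices by current_count / target_count, which starts at the first sorted sample but, unlike the dev script above, does not reach the end of the list when current_count > target_count. A small sketch with illustrative numbers only:

# Same uniform-selection idea, using the training script's scaling.
current_count, target_count = 10, 4
indices = sorted({int(i * current_count / target_count) for i in range(target_count)})
print(indices)  # [0, 2, 5, 7] -- index 9 (the last sorted sample) is not selected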

@ -1 +0,0 @@
Subproject commit 7e60ab26825a452f8ee8eb19799d1cb6c1746326