parent
5eabd121a3
commit
89e07bc02d
@ -1,2 +1,3 @@
|
||||
HF_TOKEN=<your-huggingface-token>
|
||||
OPENROUTER_API_KEY=<your-openrouter-api-key>
|
||||
TAVILY_API_KEY=<your-tavily-api-key>
|
||||
SERPER_API_KEY=<your-serper-api-key>
|
@ -1,3 +0,0 @@
|
||||
[submodule "third_party/FlashRAG"]
|
||||
path = third_party/FlashRAG
|
||||
url = https://github.com/RUC-NLPIR/FlashRAG.git
|
@ -1,656 +0,0 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from flashrag.config import Config
|
||||
from flashrag.generator.generator import BaseGenerator
|
||||
from flashrag.pipeline import BasicPipeline
|
||||
from flashrag.retriever.retriever import BaseTextRetriever
|
||||
from flashrag.utils import get_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from config import logger
|
||||
from src.agent import Agent, AgenticOutputs
|
||||
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
|
||||
|
||||
|
||||
def retry(max_retries=10, sleep=1):
|
||||
"""Decorator to retry a function with exponential backoff."""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
func_name = func.__name__
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Attempt {attempt + 1} of {func_name} failed: {e}")
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(f"Function {func_name} failed after {max_retries} retries.", exc_info=True)
|
||||
raise e
|
||||
backoff_time = sleep * (2**attempt)
|
||||
logger.info(f"Retrying {func_name} in {backoff_time:.2f} seconds...")
|
||||
time.sleep(backoff_time)
|
||||
logger.error(f"Function {func_name} retry logic finished unexpectedly.")
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RemoteRetriever(BaseTextRetriever):
|
||||
"""A wrapper for remote retriever service with retry logic and logging."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""Initializes the RemoteRetriever."""
|
||||
super().__init__(config)
|
||||
self.remote_url = f"http://{getattr(config, 'remote_retriever_url', 'localhost:8001')}"
|
||||
self.topk = getattr(config, "retriever_topk", 5)
|
||||
logger.info(f"🔗 Remote retriever URL: {self.remote_url}")
|
||||
|
||||
@retry(max_retries=3, sleep=2)
|
||||
def _search(self, query: str, num: int | None = None, return_score: bool = False) -> list[dict]:
|
||||
"""Search for documents using the remote retriever service."""
|
||||
num = num if num is not None else self.topk
|
||||
url = f"{self.remote_url}/search"
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"query": query, "top_n": num, "return_score": return_score},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
return results
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Search request timed out after 30 seconds for query: {query[:50]}...")
|
||||
raise
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(f"Could not connect to search service at {url}")
|
||||
raise
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Search request failed: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected search error: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@retry(max_retries=3, sleep=2)
|
||||
def _batch_search(
|
||||
self, queries: list[str], num: int | None = None, return_score: bool = False
|
||||
) -> list[list[dict]]:
|
||||
"""Batch search for documents using the remote retriever service."""
|
||||
num = num if num is not None else self.topk
|
||||
url = f"{self.remote_url}/batch_search"
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"query": queries, "top_n": num, "return_score": return_score},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
return results
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Batch search request timed out after 60 seconds for {len(queries)} queries.")
|
||||
raise
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(f"Could not connect to batch search service at {url}")
|
||||
raise
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Batch search request failed: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected batch search error: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
class ReSearchPipeline(BasicPipeline):
|
||||
"""Pipeline for ReSearch method using Agent for generation and tool use."""
|
||||
|
||||
def __init__(
|
||||
self, config: Config, retriever: BaseTextRetriever | None = None, generator: BaseGenerator | None = None
|
||||
):
|
||||
"""Initializes the ReSearchPipeline."""
|
||||
super().__init__(config)
|
||||
logger.info("🔧 Initializing ReSearchPipeline...")
|
||||
|
||||
self.retriever = retriever or RemoteRetriever(config)
|
||||
|
||||
self.generator = generator or SGLRemoteGenerator(config)
|
||||
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.generator_model_path, trust_remote_code=True)
|
||||
if not self.tokenizer.pad_token:
|
||||
logger.warning("Tokenizer does not have a pad token; setting to eos_token.")
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
self.tokenizer.padding_side = "left"
|
||||
logger.info("✅ Tokenizer initialized.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize tokenizer: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
tokenizer_name = self.tokenizer.name_or_path.lower()
|
||||
|
||||
if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
|
||||
adapter = R1DistilTokenizerAdapter()
|
||||
elif "llama" in tokenizer_name:
|
||||
adapter = LlamaTokenizerAdapter()
|
||||
else:
|
||||
logger.warning(f"Unknown tokenizer type '{tokenizer_name}', defaulting to R1DistilTokenizerAdapter.")
|
||||
adapter = R1DistilTokenizerAdapter()
|
||||
logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")
|
||||
|
||||
def retriever_search(query: str, return_type=str, results: int = 5):
|
||||
try:
|
||||
search_results = self.retriever._search(query, num=results)
|
||||
return self.format_search_results(search_results)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during agent's retriever search for query '{query[:50]}...': {e}", exc_info=True)
|
||||
return "<information>Search failed due to an internal error."
|
||||
|
||||
self.agent = Agent(adapter, search_fn=retriever_search)
|
||||
logger.info("✅ Agent initialized.")
|
||||
logger.info("✅ ReSearchPipeline initialized successfully.")
|
||||
|
||||
def format_search_results(self, search_results: list[dict]) -> str:
|
||||
"""Formats search results into a string for the agent prompt."""
|
||||
if not search_results:
|
||||
return "<information>No results found.</information>"
|
||||
max_content_len = 500
|
||||
formatted = "\n-------\n".join(
|
||||
[
|
||||
f"Result {i + 1}: {r.get('contents', 'N/A')[:max_content_len]}{'...' if len(r.get('contents', '')) > max_content_len else ''}"
|
||||
for i, r in enumerate(search_results)
|
||||
]
|
||||
)
|
||||
formatted_str = f"<information>{formatted}</information>"
|
||||
|
||||
return formatted_str
|
||||
|
||||
def extract_search_query(self, text: str) -> str | None:
|
||||
"""Extract search query from text between <search> tags."""
|
||||
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
|
||||
matches = pattern.findall(text)
|
||||
if matches:
|
||||
query = matches[-1].strip()
|
||||
return query
|
||||
return None
|
||||
|
||||
def extract_answer(self, text: str) -> str | None:
|
||||
"""Extract answer from text between <answer> tags."""
|
||||
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
|
||||
matches = pattern.findall(text)
|
||||
if matches:
|
||||
answer = matches[-1].strip()
|
||||
|
||||
return answer
|
||||
|
||||
return None
|
||||
|
||||
def run(self, dataset, do_eval: bool = True, pred_process_fun=None):
|
||||
"""Runs the ReSearch pipeline on the dataset using the Agent."""
|
||||
logger.info(f"🏃 Starting ReSearch pipeline run with {len(dataset)} items...")
|
||||
|
||||
try:
|
||||
questions = [item.question if hasattr(item, "question") else item["question"] for item in dataset]
|
||||
|
||||
except (KeyError, AttributeError, TypeError) as e:
|
||||
logger.error(f"Failed to extract questions from dataset items. Error: {e}", exc_info=True)
|
||||
logger.error("Ensure dataset items have a 'question' key or attribute.")
|
||||
return dataset
|
||||
|
||||
agent_max_generations = getattr(self.config, "agent_max_generations", 32)
|
||||
generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576)
|
||||
|
||||
try:
|
||||
logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
|
||||
agent_outputs: AgenticOutputs = self.agent.run_agent(
|
||||
generate_fn=self.generator.generate,
|
||||
tokenizer=self.tokenizer,
|
||||
questions=questions,
|
||||
max_generations=agent_max_generations,
|
||||
max_new_tokens=generator_max_output_len,
|
||||
)
|
||||
final_responses = agent_outputs.final_response_str
|
||||
logger.info(f"✅ Agent inference completed. Received {len(final_responses)} final responses.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent run failed during inference: {e}", exc_info=True)
|
||||
logger.warning("Agent run failed, attempting evaluation with potentially incomplete results.")
|
||||
for item in dataset:
|
||||
if hasattr(item, "update_output"):
|
||||
item.update_output("pred", "AGENT_ERROR")
|
||||
elif isinstance(item, dict):
|
||||
item["pred"] = "AGENT_ERROR"
|
||||
|
||||
logger.info("📝 Extracting answers and updating dataset items...")
|
||||
num_updated = 0
|
||||
num_missing_answers = 0
|
||||
if len(final_responses) == len(dataset):
|
||||
for i, item in enumerate(dataset):
|
||||
response = final_responses[i]
|
||||
answer = self.extract_answer(response)
|
||||
pred_to_save = answer if answer is not None else ""
|
||||
|
||||
if answer is None:
|
||||
num_missing_answers += 1
|
||||
|
||||
if hasattr(item, "update_output"):
|
||||
item.update_output("pred", pred_to_save)
|
||||
item.update_output("final_response", response)
|
||||
num_updated += 1
|
||||
elif isinstance(item, dict):
|
||||
item["pred"] = pred_to_save
|
||||
item["final_response"] = response
|
||||
num_updated += 1
|
||||
else:
|
||||
logger.warning(f"Item {i} has unknown type {type(item)}, cannot update with prediction.")
|
||||
|
||||
logger.info(f"Updated {num_updated}/{len(dataset)} dataset items with predictions.")
|
||||
if num_missing_answers > 0:
|
||||
logger.warning(f"{num_missing_answers} items had no <answer> tag.")
|
||||
else:
|
||||
logger.error(
|
||||
f"Mismatch between dataset size ({len(dataset)}) and number of agent responses ({len(final_responses)}). Cannot reliably update dataset."
|
||||
)
|
||||
for item in dataset:
|
||||
if hasattr(item, "update_output"):
|
||||
item.update_output("pred", "RESPONSE_COUNT_MISMATCH")
|
||||
elif isinstance(item, dict):
|
||||
item["pred"] = "RESPONSE_COUNT_MISMATCH"
|
||||
|
||||
if do_eval:
|
||||
logger.info("📊 Evaluating results using BasicPipeline.evaluate...")
|
||||
try:
|
||||
dataset = self.evaluate(dataset, do_eval=True, pred_process_fun=pred_process_fun)
|
||||
logger.info("✅ Evaluation completed via base class method.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during BasicPipeline.evaluate: {e}", exc_info=True)
|
||||
logger.warning("Evaluation may be incomplete.")
|
||||
else:
|
||||
logger.info("Skipping evaluation step as do_eval=False.")
|
||||
|
||||
logger.info("✅ ReSearch pipeline run finished.")
|
||||
return dataset
|
||||
|
||||
|
||||
class SGLRemoteGenerator(BaseGenerator):
|
||||
"""Class for decoder-only generator, based on SGLang remote service."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""Initializes the SGLRemoteGenerator."""
|
||||
super().__init__(config)
|
||||
logger.info("🔧 Initializing SGLRemoteGenerator...")
|
||||
sgl_url = getattr(config, "sgl_remote_url", "localhost:8002")
|
||||
self.sgl_remote_url = f"http://{sgl_url}/generate"
|
||||
self.health_check_url = f"http://{sgl_url}/health"
|
||||
logger.info(f"🔗 Remote Generator URL: {self.sgl_remote_url}")
|
||||
self.model_path = getattr(config, "generator_model_path", None)
|
||||
if not self.model_path:
|
||||
logger.error("generator_model_path not found in config!")
|
||||
raise ValueError("generator_model_path is required for SGLRemoteGenerator")
|
||||
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
|
||||
logger.info("✅ Tokenizer loaded for generator.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load tokenizer for generator from {self.model_path}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
self.generation_params = getattr(config, "generation_params", {})
|
||||
self.config = config
|
||||
|
||||
self._check_health()
|
||||
|
||||
def _check_health(self):
|
||||
"""Checks the health of the remote generator service."""
|
||||
try:
|
||||
test_response = requests.get(self.health_check_url, timeout=5)
|
||||
test_response.raise_for_status()
|
||||
logger.info("✅ Remote generator service is available")
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Could not connect or verify remote generator service at {self.health_check_url}: {str(e)}")
|
||||
logger.warning("Please ensure the SGLang service is running and accessible.")
|
||||
|
||||
@retry(max_retries=5, sleep=2)
|
||||
def generate(
|
||||
self,
|
||||
input_list: list[str] | str,
|
||||
return_raw_output: bool = False,
|
||||
return_scores: bool = False,
|
||||
**params,
|
||||
) -> list[str] | tuple[list[str], list[list[float]]] | list[dict]:
|
||||
"""Generates text using the remote SGLang service."""
|
||||
if isinstance(input_list, str):
|
||||
input_list = [input_list]
|
||||
if not isinstance(input_list, list) or not all(isinstance(item, str) for item in input_list):
|
||||
raise ValueError("Input must be a string or a list of strings.")
|
||||
|
||||
batch_size = len(input_list)
|
||||
data_to_remote = {"text": input_list}
|
||||
|
||||
effective_params = deepcopy(self.generation_params)
|
||||
effective_params.update(params)
|
||||
|
||||
curr_sampling_params = {}
|
||||
if effective_params.get("do_sample", True) is False:
|
||||
curr_sampling_params["temperature"] = 0.0
|
||||
else:
|
||||
curr_sampling_params["temperature"] = effective_params.get(
|
||||
"temperature", getattr(self.config, "temperature", 0.7)
|
||||
)
|
||||
|
||||
default_max_tokens = getattr(self.config, "generator_max_output_len", 1024)
|
||||
curr_sampling_params["max_new_tokens"] = effective_params.get("max_new_tokens", default_max_tokens)
|
||||
|
||||
stop_sequences = effective_params.get("stop", [])
|
||||
if isinstance(stop_sequences, str):
|
||||
stop_sequences = [stop_sequences]
|
||||
if stop_sequences:
|
||||
curr_sampling_params["stop"] = stop_sequences
|
||||
|
||||
keys_to_remove = ["do_sample", "temperature", "max_new_tokens", "stop"]
|
||||
for key in keys_to_remove:
|
||||
effective_params.pop(key, None)
|
||||
|
||||
if "top_p" in effective_params:
|
||||
curr_sampling_params["top_p"] = effective_params["top_p"]
|
||||
if "top_k" in effective_params:
|
||||
curr_sampling_params["top_k"] = effective_params["top_k"]
|
||||
|
||||
data_to_remote["sampling_params"] = curr_sampling_params
|
||||
|
||||
if return_scores:
|
||||
data_to_remote["return_logprob"] = True
|
||||
data_to_remote["top_logprobs_num"] = getattr(self.config, "top_logprobs_num", 2)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.sgl_remote_url, json=data_to_remote, timeout=120, headers={"Content-Type": "application/json"}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_list = response.json()
|
||||
|
||||
if return_raw_output:
|
||||
return response_list
|
||||
|
||||
generated_text = []
|
||||
for item in response_list:
|
||||
text = item.get("text", "")
|
||||
finish_reason = item.get("meta_info", {}).get("finish_reason", {})
|
||||
matched_stop = finish_reason.get("matched")
|
||||
if matched_stop and curr_sampling_params.get("stop") and matched_stop in curr_sampling_params["stop"]:
|
||||
text += matched_stop
|
||||
generated_text.append(text)
|
||||
|
||||
if return_scores:
|
||||
scores = []
|
||||
for resp_item in response_list:
|
||||
logprobs_list = resp_item.get("meta_info", {}).get("output_token_logprobs", [])
|
||||
token_scores = [
|
||||
np.exp(logprob[0]) if (logprob and len(logprob) > 0) else 0.0 for logprob in logprobs_list
|
||||
]
|
||||
scores.append(token_scores)
|
||||
return generated_text, scores
|
||||
else:
|
||||
return generated_text
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error("Generation request timed out after 120 seconds.")
|
||||
raise
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(f"Could not connect to remote generator service at {self.sgl_remote_url}.")
|
||||
raise
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Network error during generation: {str(e)}", exc_info=True)
|
||||
raise
|
||||
except json.JSONDecodeError:
|
||||
response_text = "Unknown (error occurred before response object assignment)"
|
||||
if "response" in locals() and hasattr(response, "text"):
|
||||
response_text = response.text[:500]
|
||||
logger.error(
|
||||
f"Failed to decode JSON response from {self.sgl_remote_url}. Response text: {response_text}...",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def load_dataset_items(config: Config, split: str) -> list[dict | object]:
|
||||
"""Loads dataset items using flashrag's get_dataset."""
|
||||
logger.info(f"📚 Loading dataset: {config.dataset_name}, Split: {split}")
|
||||
try:
|
||||
all_splits = get_dataset(config)
|
||||
if split not in all_splits:
|
||||
logger.error(
|
||||
f"Split '{split}' not found in dataset '{config.dataset_name}'. Available splits: {list(all_splits.keys())}"
|
||||
)
|
||||
return []
|
||||
dataset_items = all_splits[split]
|
||||
logger.info(f"Successfully loaded {len(dataset_items)} items for split '{split}'.")
|
||||
|
||||
return dataset_items
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
f"Dataset files not found for '{config.dataset_name}' in '{config.data_dir}'. Check config and paths."
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dataset using get_dataset: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
def save_results(args: argparse.Namespace, config: Config, result_dataset, run_duration: float):
|
||||
"""Saves summary and debug information."""
|
||||
logger.info("💾 Saving results...")
|
||||
summary_file = os.path.join(args.save_dir, f"{args.save_note}_summary.txt")
|
||||
debug_file = os.path.join(args.save_dir, f"{args.save_note}_debug.json")
|
||||
|
||||
num_items = len(result_dataset)
|
||||
|
||||
logger.info(f"Saving summary results to {summary_file}...")
|
||||
try:
|
||||
with open(summary_file, "w", encoding="utf-8") as f:
|
||||
f.write("EVALUATION SUMMARY\n")
|
||||
f.write("=================\n\n")
|
||||
f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"Run Duration: {run_duration:.2f} seconds\n")
|
||||
f.write(f"Dataset: {config.dataset_name} ({args.split} split)\n")
|
||||
f.write(f"Model: {config.generator_model_path}\n")
|
||||
f.write(f"Retriever: {config.remote_retriever_url}\n")
|
||||
f.write(f"Agent Max Generations: {getattr(config, 'agent_max_generations', 'N/A')}\n")
|
||||
f.write(f"Generator Max Output Len: {getattr(config, 'generator_max_output_len', 'N/A')}\n\n")
|
||||
f.write(f"Total items processed: {num_items}\n")
|
||||
f.write("\nNote: Verification was skipped in this run.\n")
|
||||
f.write("Note: Overall metrics (like EM, F1) are usually printed to console by evaluate method.\n")
|
||||
|
||||
logger.info(f"✅ Summary saved to {summary_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving summary file '{summary_file}': {e}", exc_info=True)
|
||||
|
||||
logger.info(f"Saving debug information (predictions & responses) to {debug_file}...")
|
||||
try:
|
||||
debug_data = []
|
||||
for i, item in enumerate(result_dataset):
|
||||
item_data: dict[str, object] = {}
|
||||
|
||||
def get_item_value(data_item, key_or_attr: str) -> str | int | float | list | bool | None:
|
||||
if isinstance(data_item, dict):
|
||||
return data_item.get(key_or_attr)
|
||||
elif hasattr(data_item, key_or_attr):
|
||||
return getattr(data_item, key_or_attr)
|
||||
return None
|
||||
|
||||
item_data["item_index"] = i
|
||||
item_data["question"] = get_item_value(item, "question")
|
||||
item_data["prediction"] = get_item_value(item, "pred")
|
||||
item_data["final_response"] = get_item_value(item, "final_response")
|
||||
|
||||
gt_answer_val = None
|
||||
try:
|
||||
gt_answer_val = get_item_value(item, "answer")
|
||||
if gt_answer_val is None:
|
||||
answers_list = get_item_value(item, "answers")
|
||||
if isinstance(answers_list, list) and answers_list:
|
||||
raw_ans = answers_list[0]
|
||||
if isinstance(raw_ans, (str, int, float, bool)):
|
||||
gt_answer_val = raw_ans
|
||||
else:
|
||||
gt_answer_val = str(raw_ans)
|
||||
elif not isinstance(gt_answer_val, (str, int, float, bool)):
|
||||
gt_answer_val = str(gt_answer_val)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not safely get ground truth for item {i}: {e}")
|
||||
gt_answer_val = "ERROR_GETTING_ANSWER"
|
||||
item_data["ground_truth"] = gt_answer_val
|
||||
|
||||
eval_score_val = None
|
||||
try:
|
||||
eval_score_val = get_item_value(item, "score")
|
||||
if not isinstance(eval_score_val, (str, int, float, bool, type(None))):
|
||||
eval_score_val = str(eval_score_val)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not safely get score for item {i}: {e}")
|
||||
eval_score_val = "ERROR_GETTING_SCORE"
|
||||
item_data["eval_score"] = eval_score_val
|
||||
|
||||
debug_data.append(item_data)
|
||||
|
||||
with open(debug_file, "w", encoding="utf-8") as f:
|
||||
json.dump(debug_data, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"✅ Debug information saved to {debug_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving debug file '{debug_file}': {e}", exc_info=True)
|
||||
|
||||
|
||||
def research(args: argparse.Namespace, config: Config):
|
||||
"""Main function to run the research evaluation pipeline."""
|
||||
logger.info("🚀 Starting research pipeline execution...")
|
||||
start_time = time.time()
|
||||
|
||||
test_data = load_dataset_items(config, args.split)
|
||||
if not test_data:
|
||||
logger.error("Failed to load test data. Exiting.")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🏗️ Building ReSearchPipeline...")
|
||||
pipeline = ReSearchPipeline(config)
|
||||
logger.info("✅ Pipeline built successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
agent_max_generations = getattr(config, "agent_max_generations", 32)
|
||||
generator_max_output_len = getattr(config, "generator_max_output_len", 24576)
|
||||
|
||||
try:
|
||||
logger.info("🏃 Starting pipeline run...")
|
||||
result_dataset = pipeline.run(test_data, do_eval=True)
|
||||
logger.info("✅ Pipeline run completed.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during pipeline run: {e}", exc_info=True)
|
||||
result_dataset = test_data
|
||||
logger.warning("Pipeline run failed, attempting to save inputs/partial results.")
|
||||
|
||||
run_duration = time.time() - start_time
|
||||
logger.info(f"Total run duration: {run_duration:.2f} seconds.")
|
||||
save_results(args, config, result_dataset, run_duration)
|
||||
|
||||
logger.info("🏁 Research pipeline execution finished.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Running ReSearch Evaluation Pipeline")
|
||||
parser.add_argument(
|
||||
"--config_path", type=str, default="./eval_config.yaml", help="Path to the main FlashRAG config file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default="bamboogle",
|
||||
help="Name of the dataset (must match config or data_dir structure).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split", type=str, default="test", help="Dataset split to evaluate (e.g., test, validation)."
|
||||
)
|
||||
parser.add_argument("--save_dir", type=str, default="./output_logs", help="Directory to save logs and results.")
|
||||
parser.add_argument("--save_note", type=str, default="research_run", help="A note to prepend to saved filenames.")
|
||||
|
||||
parser.add_argument("--data_dir", type=str, help="Override data directory specified in config.")
|
||||
parser.add_argument(
|
||||
"--sgl_remote_url", type=str, help="Override SGLang remote generator URL (e.g., localhost:8002)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_retriever_url", type=str, help="Override remote retriever URL (e.g., localhost:8001)."
|
||||
)
|
||||
parser.add_argument("--generator_model_path", type=str, help="Override generator model path specified in config.")
|
||||
parser.add_argument("--retriever_topk", type=int, help="Override retriever top K.")
|
||||
parser.add_argument("--generator_max_output_len", type=int, help="Override generator max output length.")
|
||||
parser.add_argument("--agent_max_generations", type=int, help="Override agent max interaction turns.")
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info(f"Starting evaluation script with arguments: {args}")
|
||||
|
||||
try:
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
logger.info(f"💾 Logs and results will be saved to: {args.save_dir}")
|
||||
except OSError as e:
|
||||
logger.error(f"Could not create save directory '{args.save_dir}': {e}", exc_info=True)
|
||||
exit(1)
|
||||
|
||||
config_overrides = {
|
||||
k: v
|
||||
for k, v in vars(args).items()
|
||||
if v is not None
|
||||
and k
|
||||
not in [
|
||||
"config_path",
|
||||
"dataset_name",
|
||||
"split",
|
||||
"save_dir",
|
||||
"save_note",
|
||||
]
|
||||
}
|
||||
|
||||
logger.info(f"🔧 Loading configuration from: {args.config_path}")
|
||||
try:
|
||||
config = Config(args.config_path, config_dict=config_overrides)
|
||||
config.dataset_name = args.dataset_name
|
||||
if args.data_dir:
|
||||
config.data_dir = args.data_dir
|
||||
|
||||
logger.info(f"Effective data_dir: {getattr(config, 'data_dir', 'N/A')}")
|
||||
logger.info(f"Effective generator_model_path: {getattr(config, 'generator_model_path', 'N/A')}")
|
||||
logger.info(f"Effective sgl_remote_url: {getattr(config, 'sgl_remote_url', 'N/A')}")
|
||||
logger.info(f"Effective remote_retriever_url: {getattr(config, 'remote_retriever_url', 'N/A')}")
|
||||
|
||||
logger.info("✅ Config loaded and potentially overridden by CLI arguments.")
|
||||
|
||||
config["dataset_path"] = os.path.join(config.data_dir, config.dataset_name)
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Config file not found at '{args.config_path}'. Please check the path.")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading or processing configuration: {e}", exc_info=True)
|
||||
exit(1)
|
||||
|
||||
research(args, config)
|
@ -1,68 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from config import DATA_DIR
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command line arguments.
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Download FlashRAG datasets from HuggingFace Hub")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="RUC-NLPIR/FlashRAG_datasets",
|
||||
help="HuggingFace repository IDs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
default=DATA_DIR / "flashrag_datasets",
|
||||
help="Local directory to save model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to download model."""
|
||||
args = parse_args()
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Configuration
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
|
||||
ALLOW_PATTERNS = [
|
||||
"*retrieval-corpus*",
|
||||
"*bamboogle*",
|
||||
"*nq*",
|
||||
]
|
||||
|
||||
# Download the model
|
||||
snapshot_download(
|
||||
token=HF_TOKEN,
|
||||
repo_id=args.repo_id,
|
||||
local_dir=args.local_dir,
|
||||
repo_type="dataset",
|
||||
# ignore_patterns=IGNORE_PATTERNS,
|
||||
allow_patterns=ALLOW_PATTERNS,
|
||||
)
|
||||
|
||||
# unzip data/flashrag_datasets/retrieval-corpus/wiki18_100w.zip
|
||||
print("Unzipping wiki18_100w.zip. Might take a while...")
|
||||
zip_file_path = os.path.join(args.local_dir, "retrieval-corpus", "wiki18_100w.zip")
|
||||
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
|
||||
zip_ref.extractall(args.local_dir)
|
||||
|
||||
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,43 +0,0 @@
|
||||
"""Download and extract FlashRAG index."""
|
||||
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from config import DATA_DIR
|
||||
|
||||
# Constants
|
||||
URL = "https://www.modelscope.cn/datasets/hhjinjiajie/FlashRAG_Dataset/resolve/master/retrieval_corpus/wiki18_100w_e5_index.zip"
|
||||
ZIP_NAME = "wiki18_100w_e5_index.zip"
|
||||
zip_path = DATA_DIR / ZIP_NAME
|
||||
|
||||
# Download with progress bar
|
||||
print("📥 Downloading index...")
|
||||
response = requests.get(URL, stream=True)
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
|
||||
with (
|
||||
open(zip_path, "wb") as f,
|
||||
tqdm(
|
||||
desc=ZIP_NAME,
|
||||
total=total_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as bar,
|
||||
):
|
||||
for data in response.iter_content(chunk_size=1024):
|
||||
size = f.write(data)
|
||||
bar.update(size)
|
||||
|
||||
# Extract
|
||||
print("📦 Extracting index...")
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(DATA_DIR)
|
||||
|
||||
# Clean up zip
|
||||
os.remove(zip_path)
|
||||
print("✅ Download and extraction completed successfully!")
|
||||
print(f"Index file is at: {DATA_DIR}/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index")
|
@ -1,54 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command line arguments.
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=GENERATOR_MODEL_REPO_ID,
|
||||
help="HuggingFace repository ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
default=GENERATOR_MODEL_DIR,
|
||||
help="Local directory to save model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to download model."""
|
||||
args = parse_args()
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Configuration
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
|
||||
print("Downloading model to", args.local_dir)
|
||||
|
||||
# Download the model
|
||||
snapshot_download(
|
||||
token=HF_TOKEN,
|
||||
repo_id=args.repo_id,
|
||||
local_dir=args.local_dir,
|
||||
repo_type="model",
|
||||
)
|
||||
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,54 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from config import RETRIEVER_MODEL_DIR, RETRIEVER_MODEL_REPO_ID
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command line arguments.
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=RETRIEVER_MODEL_REPO_ID,
|
||||
help="HuggingFace repository ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
default=RETRIEVER_MODEL_DIR,
|
||||
help="Local directory to save model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to download model."""
|
||||
args = parse_args()
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Configuration
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
|
||||
print("Downloading model to", args.local_dir)
|
||||
|
||||
# Download the model
|
||||
snapshot_download(
|
||||
token=HF_TOKEN,
|
||||
repo_id=args.repo_id,
|
||||
local_dir=args.local_dir,
|
||||
repo_type="model",
|
||||
)
|
||||
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,9 +0,0 @@
|
||||
# ------------------------------------------------Environment Settings------------------------------------------------#
|
||||
gpu_id: "0"
|
||||
|
||||
# -------------------------------------------------Retrieval Settings------------------------------------------------#
|
||||
# If set the name, the model path will be find in global paths
|
||||
retrieval_method: "e5" # name or path of the retrieval model.
|
||||
index_path: "/mnt/nas/thinhlpg/data/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index" # path to the indexed file
|
||||
faiss_gpu: False # whether use gpu to hold index
|
||||
corpus_path: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/wiki18_100w.jsonl" # path to corpus in '.jsonl' format that store the documents
|
@ -1,127 +0,0 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to sys.path to allow importing config
|
||||
# Assuming the script is at DeepSearch/scripts/serving/serve_generator.py
|
||||
# The project root (DeepSearch) is parents[2]
|
||||
PROJ_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJ_ROOT) not in sys.path:
|
||||
sys.path.append(str(PROJ_ROOT))
|
||||
|
||||
# Import after adjusting sys.path
|
||||
try:
|
||||
from config import (
|
||||
GENERATOR_MODEL_REPO_ID,
|
||||
GENERATOR_SERVER_PORT,
|
||||
MODEL_CONFIG,
|
||||
logger,
|
||||
)
|
||||
except ImportError as e:
|
||||
# Use print here as logger might not be available if import failed
|
||||
print(
|
||||
f"Error importing config: {e}. Make sure config.py is in the project root ({PROJ_ROOT}) and added to sys.path."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def launch_sglang_server(
|
||||
model_id: str,
|
||||
port: int,
|
||||
context_length: int,
|
||||
host: str = "0.0.0.0",
|
||||
dtype: str = "bfloat16",
|
||||
) -> None:
|
||||
"""Launches the SGLang server using specified configurations.
|
||||
|
||||
Args:
|
||||
model_id: The Hugging Face repository ID of the model.
|
||||
port: The port number for the server.
|
||||
context_length: The maximum context length for the model.
|
||||
host: The host address for the server.
|
||||
dtype: The data type for the model (e.g., 'bfloat16', 'float16').
|
||||
"""
|
||||
command = [
|
||||
sys.executable, # Use the current Python interpreter
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--model-path",
|
||||
model_id,
|
||||
"--context-length",
|
||||
str(context_length),
|
||||
"--enable-metrics",
|
||||
"--dtype",
|
||||
dtype,
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
str(port),
|
||||
"--mem-fraction-static",
|
||||
"0.5",
|
||||
"--trust-remote-code",
|
||||
# Recommended by SGLang for stability sometimes
|
||||
"--disable-overlap",
|
||||
# Can sometimes cause issues
|
||||
"--disable-radix-cache",
|
||||
]
|
||||
|
||||
# Log the command clearly
|
||||
command_str = " ".join(command)
|
||||
logger.info(f"🚀 Launching SGLang server with command: {command_str}")
|
||||
|
||||
process = None # Initialize process to None
|
||||
try:
|
||||
# Use Popen to start the server process
|
||||
# It runs in the foreground relative to this script,
|
||||
# but allows us to catch KeyboardInterrupt cleanly.
|
||||
process = subprocess.Popen(command)
|
||||
# Wait for the process to complete (e.g., user interruption)
|
||||
process.wait()
|
||||
# Check return code after waiting
|
||||
if process.returncode != 0:
|
||||
logger.error(f"💥 SGLang server process exited with error code: {process.returncode}")
|
||||
sys.exit(process.returncode)
|
||||
else:
|
||||
logger.info("✅ SGLang server process finished gracefully.")
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"💥 Error: Python executable or sglang module not found.")
|
||||
logger.error(f"Ensure '{sys.executable}' is correct and sglang is installed.")
|
||||
sys.exit(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("🛑 SGLang server launch interrupted by user. Stopping server...")
|
||||
# Attempt to terminate the process gracefully
|
||||
if process and process.poll() is None: # Check if process exists and is running
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5) # Wait a bit for termination
|
||||
logger.info("✅ Server terminated gracefully.")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("⚠️ Server did not terminate gracefully, forcing kill.")
|
||||
process.kill()
|
||||
sys.exit(0) # Exit cleanly after interrupt
|
||||
except Exception as e:
|
||||
# Catch any other unexpected exceptions during launch or waiting
|
||||
logger.error(f"🚨 An unexpected error occurred: {e}")
|
||||
# Ensure process is cleaned up if it exists
|
||||
if process and process.poll() is None:
|
||||
process.kill()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Get context length from config, default to 8192
|
||||
context_len = MODEL_CONFIG.get("max_seq_length", 8192)
|
||||
|
||||
logger.info("----------------------------------------------------")
|
||||
logger.info("✨ Starting SGLang Generator Server ✨")
|
||||
logger.info(f" Model ID: {GENERATOR_MODEL_REPO_ID}")
|
||||
logger.info(f" Port: {GENERATOR_SERVER_PORT}")
|
||||
logger.info(f" Context Length: {context_len}")
|
||||
logger.info("----------------------------------------------------")
|
||||
|
||||
launch_sglang_server(
|
||||
model_id=GENERATOR_MODEL_REPO_ID,
|
||||
port=GENERATOR_SERVER_PORT,
|
||||
context_length=context_len,
|
||||
)
|
@ -1,113 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from flashrag.config import Config
|
||||
from flashrag.utils import get_retriever
|
||||
from pydantic import BaseModel
|
||||
|
||||
from config import RETRIEVER_SERVER_PORT
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
retriever_list = []
|
||||
available_retrievers = deque()
|
||||
retriever_semaphore = None
|
||||
|
||||
|
||||
def init_retriever(args):
|
||||
global retriever_semaphore
|
||||
config = Config(args.config)
|
||||
for i in range(args.num_retriever):
|
||||
print(f"Initializing retriever {i + 1}/{args.num_retriever}")
|
||||
retriever = get_retriever(config)
|
||||
retriever_list.append(retriever)
|
||||
available_retrievers.append(i)
|
||||
# create a semaphore to limit the number of retrievers that can be used concurrently
|
||||
retriever_semaphore = asyncio.Semaphore(args.num_retriever)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy", "retrievers": {"total": len(retriever_list), "available": len(available_retrievers)}}
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
query: str
|
||||
top_n: int = 10
|
||||
return_score: bool = False
|
||||
|
||||
|
||||
class BatchQueryRequest(BaseModel):
|
||||
query: List[str]
|
||||
top_n: int = 10
|
||||
return_score: bool = False
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
id: str
|
||||
contents: str
|
||||
|
||||
|
||||
@app.post("/search", response_model=Union[Tuple[List[Document], List[float]], List[Document]])
|
||||
async def search(request: QueryRequest):
|
||||
query = request.query
|
||||
top_n = request.top_n
|
||||
return_score = request.return_score
|
||||
|
||||
if not query or not query.strip():
|
||||
print(f"Query content cannot be empty: {query}")
|
||||
raise HTTPException(status_code=400, detail="Query content cannot be empty")
|
||||
|
||||
async with retriever_semaphore:
|
||||
retriever_idx = available_retrievers.popleft()
|
||||
try:
|
||||
if return_score:
|
||||
results, scores = retriever_list[retriever_idx].search(query, top_n, return_score)
|
||||
return [Document(id=result["id"], contents=result["contents"]) for result in results], scores
|
||||
else:
|
||||
results = retriever_list[retriever_idx].search(query, top_n, return_score)
|
||||
return [Document(id=result["id"], contents=result["contents"]) for result in results]
|
||||
finally:
|
||||
available_retrievers.append(retriever_idx)
|
||||
|
||||
|
||||
@app.post("/batch_search", response_model=Union[List[List[Document]], Tuple[List[List[Document]], List[List[float]]]])
|
||||
async def batch_search(request: BatchQueryRequest):
|
||||
query = request.query
|
||||
top_n = request.top_n
|
||||
return_score = request.return_score
|
||||
|
||||
async with retriever_semaphore:
|
||||
retriever_idx = available_retrievers.popleft()
|
||||
try:
|
||||
if return_score:
|
||||
results, scores = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
|
||||
return [
|
||||
[Document(id=result["id"], contents=result["contents"]) for result in results[i]]
|
||||
for i in range(len(results))
|
||||
], scores
|
||||
else:
|
||||
results = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
|
||||
return [
|
||||
[Document(id=result["id"], contents=result["contents"]) for result in results[i]]
|
||||
for i in range(len(results))
|
||||
]
|
||||
finally:
|
||||
available_retrievers.append(retriever_idx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="./retriever_config.yaml")
|
||||
parser.add_argument("--num_retriever", type=int, default=1)
|
||||
parser.add_argument("--port", type=int, default=RETRIEVER_SERVER_PORT)
|
||||
args = parser.parse_args()
|
||||
|
||||
init_retriever(args)
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
@ -1,135 +0,0 @@
|
||||
import json
|
||||
import math # Import math for ceiling division
|
||||
import sys
|
||||
import traceback # Import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
# Add project root to Python path if needed (adjust relative path as necessary)
|
||||
project_root = Path(__file__).resolve().parent.parent
|
||||
sys.path.append(str(project_root))
|
||||
|
||||
from src.embeddings import CustomHuggingFaceEmbeddings
|
||||
|
||||
# Import FAISS after potentially adding to sys.path
|
||||
try:
|
||||
from langchain_community.vectorstores import FAISS
|
||||
except ImportError:
|
||||
print("Error: langchain_community or FAISS not installed. Please install with 'pip install langchain faiss-cpu'")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def build_faiss_index_from_csv(csv_path: str, index_save_path: str, batch_size: int = 128) -> None:
|
||||
"""Builds a FAISS index from a CSV containing paragraph content and metadata.
|
||||
|
||||
Reads a CSV file, generates embeddings for the 'content' column in batches,
|
||||
and saves the FAISS index files (index.faiss, index.pkl) locally.
|
||||
|
||||
Args:
|
||||
csv_path: Path to the input CSV file (e.g., data/processed/paragraphs.csv).
|
||||
index_save_path: Path to the directory where the index files should be saved.
|
||||
batch_size: Number of texts to process in each embedding batch.
|
||||
"""
|
||||
print(f"Loading paragraphs from {csv_path}")
|
||||
try:
|
||||
df = pd.read_csv(csv_path)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: CSV file not found at {csv_path}. Please run the extraction script first.")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error reading CSV file: {e}")
|
||||
return
|
||||
|
||||
if "content" not in df.columns or "metadata" not in df.columns:
|
||||
print("Error: CSV file must contain 'content' and 'metadata' columns.")
|
||||
return
|
||||
|
||||
if df.empty:
|
||||
print("Warning: Input CSV file is empty. No index will be built.")
|
||||
return
|
||||
|
||||
# Prepare documents for FAISS
|
||||
texts = df["content"].astype(str).tolist()
|
||||
metadatas = []
|
||||
try:
|
||||
metadatas = [json.loads(m) for m in df["metadata"].tolist()]
|
||||
print(f"Prepared {len(texts)} texts and {len(metadatas)} metadatas.")
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing metadata JSON: {e}. Check the format in {csv_path}")
|
||||
traceback.print_exc() # Print traceback for JSON errors
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error processing metadata: {e}")
|
||||
traceback.print_exc() # Print traceback for other metadata errors
|
||||
return
|
||||
|
||||
if not texts or not metadatas or len(texts) != len(metadatas):
|
||||
print(f"Error: Mismatch or empty texts/metadatas. Texts: {len(texts)}, Metadatas: {len(metadatas)}")
|
||||
return
|
||||
|
||||
print("Initializing embeddings model...")
|
||||
try:
|
||||
embeddings = CustomHuggingFaceEmbeddings()
|
||||
except Exception as e:
|
||||
print(f"Error initializing embeddings model: {e}")
|
||||
traceback.print_exc()
|
||||
return
|
||||
print("Embeddings model initialized successfully.")
|
||||
|
||||
vectorstore = None
|
||||
num_batches = math.ceil(len(texts) / batch_size)
|
||||
print(f"Processing {len(texts)} texts in {num_batches} batches of size {batch_size}...")
|
||||
|
||||
for i in range(num_batches):
|
||||
start_idx = i * batch_size
|
||||
end_idx = min((i + 1) * batch_size, len(texts))
|
||||
batch_texts = texts[start_idx:end_idx]
|
||||
batch_metadatas = metadatas[start_idx:end_idx]
|
||||
print(f" Processing batch {i + 1}/{num_batches} (indices {start_idx}-{end_idx - 1})...")
|
||||
|
||||
try:
|
||||
if i == 0:
|
||||
# Initialize the vector store with the first batch
|
||||
print(f" Initializing FAISS index with first batch...")
|
||||
vectorstore = FAISS.from_texts(texts=batch_texts, embedding=embeddings, metadatas=batch_metadatas)
|
||||
print(" FAISS index initialized.")
|
||||
else:
|
||||
# Add subsequent batches to the existing store
|
||||
if vectorstore is None:
|
||||
print("Error: vectorstore is None after first batch, cannot add more texts.")
|
||||
return # Should not happen if first batch succeeded
|
||||
print(f" Adding batch {i + 1} to FAISS index...")
|
||||
vectorstore.add_texts(texts=batch_texts, metadatas=batch_metadatas)
|
||||
print(f" Batch {i + 1} added.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing batch {i + 1} (indices {start_idx}-{end_idx - 1}): {e}")
|
||||
traceback.print_exc()
|
||||
print("Stopping index creation due to error in batch processing.")
|
||||
return # Exit if any batch fails
|
||||
|
||||
if vectorstore is None:
|
||||
print("Error: Failed to create or add any data to the vectorstore.")
|
||||
return
|
||||
|
||||
# Save the completed index
|
||||
try:
|
||||
print(f"Attempting to save final FAISS index files to directory: {index_save_path}")
|
||||
# Ensure the target directory exists before saving
|
||||
Path(index_save_path).mkdir(parents=True, exist_ok=True)
|
||||
vectorstore.save_local(index_save_path)
|
||||
print(f"Successfully saved final FAISS index files (index.faiss, index.pkl) to: {index_save_path}")
|
||||
except Exception as e:
|
||||
print(f"Error during final vectorstore.save_local to {index_save_path}: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Define paths relative to this script or use absolute paths
|
||||
PROCESSED_DIR = Path("data/processed")
|
||||
INPUT_CSV = str(PROCESSED_DIR / "paragraphs.csv")
|
||||
# FAISS save_local will save index.faiss and index.pkl in this directory
|
||||
INDEX_SAVE_DIR = str(PROCESSED_DIR) # Save directly to processed dir
|
||||
|
||||
build_faiss_index_from_csv(INPUT_CSV, INDEX_SAVE_DIR, batch_size=128)
|
@ -1,30 +0,0 @@
|
||||
#!/bin/bash
|
||||
# This script is taken from https://github.com/StonyBrookNLP/musique with slight modifications
|
||||
|
||||
set -e
|
||||
set -x
|
||||
|
||||
# If gdown doesn't work, you can download files from mentioned URLs manually
|
||||
# and put them at appropriate locations.
|
||||
pip install gdown
|
||||
|
||||
ZIP_NAME="musique_v1.0.zip"
|
||||
|
||||
# URL: https://drive.google.com/file/d/1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h/view?usp=sharing
|
||||
gdown --id 1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h --output $ZIP_NAME
|
||||
|
||||
TARGET_DIR="./data/raw"
|
||||
mkdir -p $TARGET_DIR
|
||||
unzip -o $(basename $ZIP_NAME) -d $TARGET_DIR # Extract directly into target
|
||||
|
||||
# Move contents from the extracted 'data' folder up one level
|
||||
mv $TARGET_DIR/data/* $TARGET_DIR/
|
||||
|
||||
# Clean up the empty directory and the zip
|
||||
rm -rf $TARGET_DIR/data
|
||||
rm $ZIP_NAME
|
||||
|
||||
# TODO: prevent these from zipping in.
|
||||
rm -rf __MACOSX
|
||||
# Clean up potential extracted .DS_Store
|
||||
rm -f $TARGET_DIR/.DS_Store
|
@ -1,101 +0,0 @@
|
||||
import json
|
||||
import sys
|
||||
from collections import defaultdict # Use defaultdict for cleaner accumulation
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
|
||||
# Add project root to Python path if needed (adjust relative path as necessary)
|
||||
# project_root = Path(__file__).resolve().parent.parent
|
||||
# sys.path.append(str(project_root))
|
||||
# from config import logger # Assuming you have a logger setup
|
||||
|
||||
|
||||
def extract_unique_paragraphs(input_paths: list[str], output_csv_path: str) -> None:
|
||||
"""Extracts unique paragraphs from specified JSONL files.
|
||||
|
||||
Reads Musique JSONL files (train, dev, test), finds unique paragraphs
|
||||
(regardless of is_supporting flag), combines title and text,
|
||||
tracks source question IDs, and saves to CSV.
|
||||
|
||||
Args:
|
||||
input_paths: A list of paths to the input JSONL files.
|
||||
output_csv_path: Path to save the output CSV file.
|
||||
"""
|
||||
output_dir = Path(output_csv_path).parent
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use paragraph content as key, value is the set of source question IDs
|
||||
paragraphs_data = defaultdict(set)
|
||||
print("Starting paragraph extraction (including non-supporting)...")
|
||||
|
||||
for file_path in input_paths:
|
||||
print(f"Processing file: {file_path}")
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as infile:
|
||||
for line_num, line in enumerate(infile, 1):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
main_question_id = data.get("id")
|
||||
if not main_question_id:
|
||||
print(f"Warning: Missing 'id' in line {line_num} of {file_path}")
|
||||
continue
|
||||
|
||||
for p in data.get("paragraphs", []):
|
||||
title = p.get("title", "No Title")
|
||||
text = p.get("paragraph_text", "")
|
||||
content = f"{title}\n{text}".strip()
|
||||
|
||||
if not content:
|
||||
continue # Skip empty paragraphs
|
||||
|
||||
paragraphs_data[content].add(main_question_id)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: Skipping invalid JSON in line {line_num} of {file_path}")
|
||||
except Exception as e:
|
||||
print(f"Warning: Error processing line {line_num} in {file_path}: {e}")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Input file not found: {file_path}")
|
||||
except Exception as e:
|
||||
print(f"Error reading file {file_path}: {e}")
|
||||
|
||||
print(f"Found {len(paragraphs_data)} unique paragraphs (supporting and non-supporting).")
|
||||
|
||||
# Prepare data for DataFrame
|
||||
output_list = []
|
||||
sorted_content = sorted(paragraphs_data.keys())
|
||||
for chunk_id, content in enumerate(sorted_content, 1):
|
||||
question_ids = paragraphs_data[content]
|
||||
metadata = {"source_question_ids": sorted(list(question_ids))}
|
||||
output_list.append(
|
||||
{
|
||||
"chunk_id": chunk_id,
|
||||
"content": content,
|
||||
"metadata": json.dumps(metadata), # Store metadata as JSON string
|
||||
}
|
||||
)
|
||||
|
||||
if not output_list:
|
||||
print("No paragraphs found to save.")
|
||||
return
|
||||
df = pd.DataFrame(output_list)
|
||||
try:
|
||||
df.to_csv(output_csv_path, index=False)
|
||||
print(f"Successfully saved unique paragraphs to {output_csv_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving CSV file: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
RAW_DIR = Path("data/raw")
|
||||
PROCESSED_DIR = Path("data/processed")
|
||||
|
||||
input_files = [
|
||||
str(RAW_DIR / "musique_ans_v1.0_train.jsonl"),
|
||||
str(RAW_DIR / "musique_ans_v1.0_dev.jsonl"),
|
||||
str(RAW_DIR / "musique_ans_v1.0_test.jsonl"),
|
||||
]
|
||||
output_csv = str(PROCESSED_DIR / "paragraphs.csv")
|
||||
|
||||
extract_unique_paragraphs(input_files, output_csv)
|
@ -1,155 +0,0 @@
|
||||
"""Prepares a deterministic sampled dev set (questions_dev.jsonl) from raw Musique dev data."""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def transform_musique_dev_data(input_path: str, output_path: str, sample_config: dict) -> None:
|
||||
"""Transforms Musique dev data with deterministic stratified sampling using uniform selection from sorted lists.
|
||||
|
||||
Reads dev data, categorizes by hop type (2, 3, 4), sorts categories by ID,
|
||||
selects N samples uniformly spaced from each sorted category based on sample_config,
|
||||
combines, sorts final list by ID, combines answers/aliases, extracts supporting paras,
|
||||
and writes the transformed data to output_path.
|
||||
|
||||
Args:
|
||||
input_path: Path to the input JSONL file (e.g., data/raw/musique_ans_v1.0_dev.jsonl).
|
||||
output_path: Path to the output JSONL file (e.g., data/processed/questions_dev.jsonl).
|
||||
sample_config: Dictionary specifying samples per hop type (e.g., {"2hop": 20, "3hop": 15, "4hop": 15}).
|
||||
"""
|
||||
output_dir = Path(output_path).parent
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Reading all data from {input_path} for dev sampling...")
|
||||
all_data = []
|
||||
try:
|
||||
with open(input_path, "r", encoding="utf-8") as infile:
|
||||
for line_num, line in enumerate(infile, 1):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
if "id" in data:
|
||||
all_data.append(data)
|
||||
else:
|
||||
print(f"Warning: Skipping line {line_num} due to missing 'id' field in {input_path}")
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: Skipping invalid JSON in line {line_num} of {input_path}")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Input file not found at {input_path}")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error reading file {input_path}: {e}")
|
||||
return
|
||||
print(f"Read {len(all_data)} total samples from dev set.")
|
||||
|
||||
# Categorize data by hop count (2hop, 3hop, 4hop)
|
||||
categorized_data = defaultdict(list)
|
||||
print("Categorizing data by hop type (2, 3, 4)...")
|
||||
for data in all_data:
|
||||
q_id = data["id"]
|
||||
hop_type = None
|
||||
if q_id.startswith("2hop"):
|
||||
hop_type = "2hop"
|
||||
elif q_id.startswith("3hop"):
|
||||
hop_type = "3hop"
|
||||
elif q_id.startswith("4hop"):
|
||||
hop_type = "4hop"
|
||||
|
||||
if hop_type:
|
||||
categorized_data[hop_type].append(data)
|
||||
|
||||
# Deterministic sampling using sorting and uniform index selection
|
||||
final_sample_list = []
|
||||
total_target = sum(sample_config.values())
|
||||
print(f"Sampling deterministically via uniform selection from sorted lists to get {total_target} dev samples...")
|
||||
|
||||
for hop_type, target_count in sample_config.items():
|
||||
available_samples = categorized_data.get(hop_type, [])
|
||||
current_count = len(available_samples)
|
||||
print(f" {hop_type}: Found {current_count} samples, need {target_count}.")
|
||||
|
||||
if current_count == 0:
|
||||
continue
|
||||
|
||||
available_samples.sort(key=lambda x: x["id"])
|
||||
selected_samples_for_hop = []
|
||||
if current_count < target_count:
|
||||
print(f" Warning: Not enough samples for {hop_type}. Taking all {current_count} sorted samples.")
|
||||
selected_samples_for_hop = available_samples
|
||||
elif target_count > 0: # Ensure target_count is positive before selecting
|
||||
print(f" Selecting {target_count} samples uniformly from {current_count}...")
|
||||
# Calculate indices using integer interpretation of evenly spaced points
|
||||
indices_to_take = [
|
||||
int(i * (current_count - 1) / (target_count - 1)) if target_count > 1 else 0
|
||||
for i in range(target_count)
|
||||
] # Adjust index calc for edges
|
||||
indices_to_take = sorted(list(set(indices_to_take))) # Ensure unique indices
|
||||
# Simple fallback if uniqueness reduced count below target
|
||||
while len(indices_to_take) < target_count and len(indices_to_take) < current_count:
|
||||
next_val = indices_to_take[-1] + 1
|
||||
if next_val < current_count:
|
||||
indices_to_take.append(next_val)
|
||||
else: # Cannot add more unique indices
|
||||
break
|
||||
selected_samples_for_hop = [
|
||||
available_samples[idx] for idx in indices_to_take[:target_count]
|
||||
] # Select based on unique indices, capped at target
|
||||
|
||||
final_sample_list.extend(selected_samples_for_hop)
|
||||
|
||||
print(f"Selected {len(final_sample_list)} dev samples in total.")
|
||||
|
||||
# Sort the final combined list by ID for consistent output order
|
||||
print("Sorting the final combined dev sample list by ID...")
|
||||
final_sample_list.sort(key=lambda x: x["id"])
|
||||
|
||||
# Process and write the selected samples
|
||||
print(f"Processing and writing {len(final_sample_list)} selected dev samples to {output_path}...")
|
||||
count = 0
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as outfile:
|
||||
for data in final_sample_list:
|
||||
try:
|
||||
supporting_paragraphs = [
|
||||
p["paragraph_text"] for p in data.get("paragraphs", []) if p.get("is_supporting", False)
|
||||
]
|
||||
main_answer = data.get("answer", "")
|
||||
aliases = data.get("answer_aliases", [])
|
||||
all_answers = [main_answer] + (aliases if isinstance(aliases, list) else [])
|
||||
valid_answers = [str(ans).strip() for ans in all_answers if ans and str(ans).strip()]
|
||||
unique_valid_answers = list(set(valid_answers)) # Keep unique, don't sort alphabetically
|
||||
combined_answer_str = " OR ".join(unique_valid_answers)
|
||||
|
||||
output_data = {
|
||||
"id": data.get("id"),
|
||||
"question": data.get("question"),
|
||||
"answer": combined_answer_str,
|
||||
"supporting_paragraphs": supporting_paragraphs,
|
||||
}
|
||||
outfile.write(json.dumps(output_data) + "\n")
|
||||
count += 1
|
||||
except KeyError as e:
|
||||
print(f"Skipping sample due to missing key {e}: {data.get('id')}")
|
||||
print(f"Successfully processed and wrote {count} dev samples.")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred during writing: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Define file paths relative to the project root
|
||||
# Ensure this script is run from the project root or adjust paths accordingly
|
||||
RAW_DIR = Path("data/raw")
|
||||
PROCESSED_DIR = Path("data/processed")
|
||||
|
||||
# Define sampling configuration for the dev set
|
||||
DEV_SAMPLING_CONFIG = {"2hop": 20, "3hop": 15, "4hop": 15} # Total = 50
|
||||
|
||||
INPUT_FILE = RAW_DIR / "musique_ans_v1.0_dev.jsonl"
|
||||
OUTPUT_FILE = PROCESSED_DIR / "questions_dev.jsonl"
|
||||
|
||||
transform_musique_dev_data(str(INPUT_FILE), str(OUTPUT_FILE), DEV_SAMPLING_CONFIG)
|
||||
|
||||
print(f"\nMusique DEV JSONL transformation and deterministic sampling complete.")
|
@ -1,172 +0,0 @@
|
||||
import json
|
||||
import math # Keep math import
|
||||
import os
|
||||
import re # Import re for parsing ID
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
# import random # No longer needed
|
||||
# SEED = 42 # No longer needed
|
||||
# random.seed(SEED) # No longer needed
|
||||
|
||||
|
||||
def transform_musique_data(input_path: str, output_path: str, sample_config: dict) -> None:
|
||||
"""Transforms Musique data with deterministic stratified sampling using uniform selection from sorted lists.
|
||||
|
||||
Reads data, categorizes by detailed hop type, sorts categories by ID,
|
||||
selects N samples uniformly spaced from each sorted category,
|
||||
combines, sorts final list by ID, and writes to output.
|
||||
|
||||
Args:
|
||||
input_path: Path to the input JSONL file.
|
||||
output_path: Path to the output JSONL file.
|
||||
sample_config: Dictionary specifying samples per detailed hop type (e.g., {"2hop": 400, "3hop1": 150, ...}).
|
||||
"""
|
||||
output_dir = Path(output_path).parent
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"Reading all data from {input_path} for sampling...")
|
||||
all_data = []
|
||||
try:
|
||||
with open(input_path, "r", encoding="utf-8") as infile:
|
||||
for line_num, line in enumerate(infile, 1):
|
||||
try:
|
||||
data = json.loads(line)
|
||||
if "id" in data:
|
||||
all_data.append(data)
|
||||
else:
|
||||
print(f"Warning: Skipping line {line_num} due to missing 'id' field in {input_path}")
|
||||
except json.JSONDecodeError:
|
||||
print(f"Warning: Skipping invalid JSON in line {line_num} of {input_path}")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Input file not found at {input_path}")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"Error reading file {input_path}: {e}")
|
||||
return
|
||||
print(f"Read {len(all_data)} total samples with IDs.")
|
||||
|
||||
# Detailed Categorization by hop type
|
||||
categorized_data = defaultdict(list)
|
||||
print("Categorizing data by detailed hop type (e.g., 3hop1, 4hop2)...")
|
||||
for data in all_data:
|
||||
q_id = data["id"]
|
||||
match = re.match(r"^(2hop|3hop[12]|4hop[123])__", q_id)
|
||||
if match:
|
||||
detailed_hop_type = match.group(1)
|
||||
categorized_data[detailed_hop_type].append(data)
|
||||
# else: # Optional: log if an ID doesn't match expected pattern
|
||||
# print(f"Warning: ID {q_id} does not match expected hop pattern.")
|
||||
|
||||
# Deterministic sampling using sorting and uniform index selection
|
||||
final_sample_list = []
|
||||
total_target = sum(sample_config.values())
|
||||
print(f"Sampling deterministically via uniform selection from sorted lists to get {total_target} samples...")
|
||||
# Check if all requested hop types exist in config
|
||||
for hop_type in sample_config.keys():
|
||||
if hop_type not in categorized_data:
|
||||
print(f"Warning: Hop type '{hop_type}' requested in config but not found in data.")
|
||||
|
||||
for hop_type, target_count in sample_config.items():
|
||||
available_samples = categorized_data.get(hop_type, [])
|
||||
current_count = len(available_samples)
|
||||
print(f" {hop_type}: Found {current_count} samples, need {target_count}.")
|
||||
|
||||
if current_count == 0:
|
||||
continue
|
||||
|
||||
# Sort the list for this category by ID
|
||||
available_samples.sort(key=lambda x: x["id"])
|
||||
|
||||
selected_samples_for_hop = []
|
||||
if current_count < target_count:
|
||||
print(f" Warning: Not enough samples for {hop_type}. Taking all {current_count} sorted samples.")
|
||||
selected_samples_for_hop = available_samples
|
||||
else:
|
||||
# Select target_count indices spread uniformly across the available samples
|
||||
print(f" Selecting {target_count} samples uniformly from {current_count}...")
|
||||
# Calculate indices using integer interpretation of evenly spaced points
|
||||
indices_to_take = [int(i * current_count / target_count) for i in range(target_count)]
|
||||
# Ensure uniqueness in case of rounding issues with small numbers (though unlikely here)
|
||||
indices_to_take = sorted(list(set(indices_to_take)))
|
||||
# Adjust if rounding resulted in fewer than target_count unique indices
|
||||
while len(indices_to_take) < target_count:
|
||||
# This is a fallback, shouldn't happen if current_count >= target_count
|
||||
# Add indices from the end if needed, avoiding duplicates
|
||||
next_idx = indices_to_take[-1] + 1
|
||||
if next_idx < current_count and next_idx not in indices_to_take:
|
||||
indices_to_take.append(next_idx)
|
||||
else: # Should not be reachable if logic is sound
|
||||
break
|
||||
|
||||
# Select samples at the calculated indices
|
||||
selected_samples_for_hop = [
|
||||
available_samples[idx] for idx in indices_to_take[:target_count]
|
||||
] # Ensure we take exactly target_count
|
||||
|
||||
final_sample_list.extend(selected_samples_for_hop)
|
||||
|
||||
print(f"Selected {len(final_sample_list)} samples in total.")
|
||||
|
||||
# Sort the final combined list by ID for consistent output order
|
||||
print("Sorting the final combined sample list by ID...")
|
||||
final_sample_list.sort(key=lambda x: x["id"])
|
||||
|
||||
# Process and write the selected samples
|
||||
print(f"Processing and writing {len(final_sample_list)} selected samples to {output_path}...")
|
||||
count = 0
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as outfile:
|
||||
for data in final_sample_list:
|
||||
try:
|
||||
supporting_paragraphs = [
|
||||
p["paragraph_text"] for p in data.get("paragraphs", []) if p.get("is_supporting", False)
|
||||
]
|
||||
|
||||
main_answer = data.get("answer", "")
|
||||
aliases = data.get("answer_aliases", [])
|
||||
|
||||
all_answers = [main_answer] + (aliases if isinstance(aliases, list) else [])
|
||||
valid_answers = [str(ans).strip() for ans in all_answers if ans and str(ans).strip()]
|
||||
unique_valid_answers = list(set(valid_answers))
|
||||
|
||||
combined_answer_str = " OR ".join(unique_valid_answers)
|
||||
|
||||
output_data = {
|
||||
"id": data.get("id"),
|
||||
"question": data.get("question"),
|
||||
"answer": combined_answer_str,
|
||||
"supporting_paragraphs": supporting_paragraphs,
|
||||
}
|
||||
outfile.write(json.dumps(output_data) + "\n")
|
||||
count += 1
|
||||
except KeyError as e:
|
||||
print(f"Skipping sample due to missing key {e}: {data.get('id')}")
|
||||
print(f"Successfully processed and wrote {count} samples.")
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred during writing: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Define file paths
|
||||
RAW_DIR = Path("data/raw")
|
||||
PROCESSED_DIR = Path("data/processed")
|
||||
|
||||
# Define detailed sampling configuration
|
||||
SAMPLING_CONFIG = {
|
||||
"2hop": 400,
|
||||
"3hop1": 150,
|
||||
"3hop2": 150,
|
||||
"4hop1": 100,
|
||||
"4hop2": 100,
|
||||
"4hop3": 100,
|
||||
} # Total = 1000
|
||||
|
||||
transform_musique_data(
|
||||
str(RAW_DIR / "musique_ans_v1.0_train.jsonl"), str(PROCESSED_DIR / "questions.jsonl"), SAMPLING_CONFIG
|
||||
)
|
||||
|
||||
print(
|
||||
"\nMusique JSONL transformation and detailed deterministic sampling (uniform selection from sorted) complete."
|
||||
)
|
||||
# Note: Dev/Test files are not processed by default with this sampling logic.
|
@ -1 +0,0 @@
|
||||
Subproject commit 7e60ab26825a452f8ee8eb19799d1cb6c1746326
|
Loading…
Reference in new issue