parent
5eabd121a3
commit
89e07bc02d
@@ -1,2 +1,3 @@
 HF_TOKEN=<your-huggingface-token>
-OPENROUTER_API_KEY=<your-openrouter-api-key>
+TAVILY_API_KEY=<your-tavily-api-key>
+SERPER_API_KEY=<your-serper-api-key>
@@ -1,3 +0,0 @@
[submodule "third_party/FlashRAG"]
	path = third_party/FlashRAG
	url = https://github.com/RUC-NLPIR/FlashRAG.git
@@ -1,656 +0,0 @@
import argparse
import json
import os
import re
import time
from copy import deepcopy
from datetime import datetime
from functools import wraps

import numpy as np
import requests
from flashrag.config import Config
from flashrag.generator.generator import BaseGenerator
from flashrag.pipeline import BasicPipeline
from flashrag.retriever.retriever import BaseTextRetriever
from flashrag.utils import get_dataset
from transformers import AutoTokenizer

from config import logger
from src.agent import Agent, AgenticOutputs
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter


def retry(max_retries=10, sleep=1):
    """Decorator to retry a function with exponential backoff."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            func_name = func.__name__
            for attempt in range(max_retries):
                try:
                    result = func(*args, **kwargs)
                    return result
                except Exception as e:
                    logger.warning(f"Attempt {attempt + 1} of {func_name} failed: {e}")
                    if attempt == max_retries - 1:
                        logger.error(f"Function {func_name} failed after {max_retries} retries.", exc_info=True)
                        raise e
                    backoff_time = sleep * (2**attempt)
                    logger.info(f"Retrying {func_name} in {backoff_time:.2f} seconds...")
                    time.sleep(backoff_time)
            logger.error(f"Function {func_name} retry logic finished unexpectedly.")
            return None

        return wrapper

    return decorator


class RemoteRetriever(BaseTextRetriever):
    """A wrapper for a remote retriever service with retry logic and logging."""

    def __init__(self, config: Config):
        """Initializes the RemoteRetriever."""
        super().__init__(config)
        self.remote_url = f"http://{getattr(config, 'remote_retriever_url', 'localhost:8001')}"
        self.topk = getattr(config, "retriever_topk", 5)
        logger.info(f"🔗 Remote retriever URL: {self.remote_url}")

    @retry(max_retries=3, sleep=2)
    def _search(self, query: str, num: int | None = None, return_score: bool = False) -> list[dict]:
        """Search for documents using the remote retriever service."""
        num = num if num is not None else self.topk
        url = f"{self.remote_url}/search"

        try:
            response = requests.post(
                url,
                json={"query": query, "top_n": num, "return_score": return_score},
                timeout=30,
            )
            response.raise_for_status()

            results = response.json()
            return results
        except requests.exceptions.Timeout:
            logger.error(f"Search request timed out after 30 seconds for query: {query[:50]}...")
            raise
        except requests.exceptions.ConnectionError:
            logger.error(f"Could not connect to search service at {url}")
            raise
        except requests.exceptions.RequestException as e:
            logger.error(f"Search request failed: {e}", exc_info=True)
            raise
        except Exception as e:
            logger.error(f"Unexpected search error: {str(e)}", exc_info=True)
            raise

    @retry(max_retries=3, sleep=2)
    def _batch_search(
        self, queries: list[str], num: int | None = None, return_score: bool = False
    ) -> list[list[dict]]:
        """Batch search for documents using the remote retriever service."""
        num = num if num is not None else self.topk
        url = f"{self.remote_url}/batch_search"

        try:
            response = requests.post(
                url,
                json={"query": queries, "top_n": num, "return_score": return_score},
                timeout=60,
            )
            response.raise_for_status()
            results = response.json()
            return results
        except requests.exceptions.Timeout:
            logger.error(f"Batch search request timed out after 60 seconds for {len(queries)} queries.")
            raise
        except requests.exceptions.ConnectionError:
            logger.error(f"Could not connect to batch search service at {url}")
            raise
        except requests.exceptions.RequestException as e:
            logger.error(f"Batch search request failed: {e}", exc_info=True)
            raise
        except Exception as e:
            logger.error(f"Unexpected batch search error: {str(e)}", exc_info=True)
            raise


class ReSearchPipeline(BasicPipeline):
    """Pipeline for the ReSearch method, using an Agent for generation and tool use."""

    def __init__(
        self, config: Config, retriever: BaseTextRetriever | None = None, generator: BaseGenerator | None = None
    ):
        """Initializes the ReSearchPipeline."""
        super().__init__(config)
        logger.info("🔧 Initializing ReSearchPipeline...")

        self.retriever = retriever or RemoteRetriever(config)

        self.generator = generator or SGLRemoteGenerator(config)

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(config.generator_model_path, trust_remote_code=True)
            if not self.tokenizer.pad_token:
                logger.warning("Tokenizer does not have a pad token; setting it to eos_token.")
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.padding_side = "left"
            logger.info("✅ Tokenizer initialized.")
        except Exception as e:
            logger.error(f"Failed to initialize tokenizer: {e}", exc_info=True)
            raise

        tokenizer_name = self.tokenizer.name_or_path.lower()

        if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
            adapter = R1DistilTokenizerAdapter()
        elif "llama" in tokenizer_name:
            adapter = LlamaTokenizerAdapter()
        else:
            logger.warning(f"Unknown tokenizer type '{tokenizer_name}', defaulting to R1DistilTokenizerAdapter.")
            adapter = R1DistilTokenizerAdapter()
        logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")

        def retriever_search(query: str, return_type=str, results: int = 5):
            try:
                search_results = self.retriever._search(query, num=results)
                return self.format_search_results(search_results)
            except Exception as e:
                logger.error(f"Error during agent's retriever search for query '{query[:50]}...': {e}", exc_info=True)
                return "<information>Search failed due to an internal error.</information>"

        self.agent = Agent(adapter, search_fn=retriever_search)
        logger.info("✅ Agent initialized.")
        logger.info("✅ ReSearchPipeline initialized successfully.")

    def format_search_results(self, search_results: list[dict]) -> str:
        """Formats search results into a string for the agent prompt."""
        if not search_results:
            return "<information>No results found.</information>"
        max_content_len = 500
        formatted = "\n-------\n".join(
            [
                f"Result {i + 1}: {r.get('contents', 'N/A')[:max_content_len]}{'...' if len(r.get('contents', '')) > max_content_len else ''}"
                for i, r in enumerate(search_results)
            ]
        )
        formatted_str = f"<information>{formatted}</information>"

        return formatted_str

    def extract_search_query(self, text: str) -> str | None:
        """Extract the search query from text between <search> tags."""
        pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
        matches = pattern.findall(text)
        if matches:
            query = matches[-1].strip()
            return query
        return None

    def extract_answer(self, text: str) -> str | None:
        """Extract the answer from text between <answer> tags."""
        pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
        matches = pattern.findall(text)
        if matches:
            answer = matches[-1].strip()

            return answer

        return None

    def run(self, dataset, do_eval: bool = True, pred_process_fun=None):
        """Runs the ReSearch pipeline on the dataset using the Agent."""
        logger.info(f"🏃 Starting ReSearch pipeline run with {len(dataset)} items...")

        try:
            questions = [item.question if hasattr(item, "question") else item["question"] for item in dataset]

        except (KeyError, AttributeError, TypeError) as e:
            logger.error(f"Failed to extract questions from dataset items. Error: {e}", exc_info=True)
            logger.error("Ensure dataset items have a 'question' key or attribute.")
            return dataset

        agent_max_generations = getattr(self.config, "agent_max_generations", 32)
        generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576)

        try:
            logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
            agent_outputs: AgenticOutputs = self.agent.run_agent(
                generate_fn=self.generator.generate,
                tokenizer=self.tokenizer,
                questions=questions,
                max_generations=agent_max_generations,
                max_new_tokens=generator_max_output_len,
            )
            final_responses = agent_outputs.final_response_str
            logger.info(f"✅ Agent inference completed. Received {len(final_responses)} final responses.")

        except Exception as e:
            logger.error(f"Agent run failed during inference: {e}", exc_info=True)
            logger.warning("Agent run failed, attempting evaluation with potentially incomplete results.")
            # Keep final_responses defined so the update logic below is skipped instead of raising NameError.
            final_responses = []
            for item in dataset:
                if hasattr(item, "update_output"):
                    item.update_output("pred", "AGENT_ERROR")
                elif isinstance(item, dict):
                    item["pred"] = "AGENT_ERROR"

        logger.info("📝 Extracting answers and updating dataset items...")
        num_updated = 0
        num_missing_answers = 0
        if final_responses and len(final_responses) == len(dataset):
            for i, item in enumerate(dataset):
                response = final_responses[i]
                answer = self.extract_answer(response)
                pred_to_save = answer if answer is not None else ""

                if answer is None:
                    num_missing_answers += 1

                if hasattr(item, "update_output"):
                    item.update_output("pred", pred_to_save)
                    item.update_output("final_response", response)
                    num_updated += 1
                elif isinstance(item, dict):
                    item["pred"] = pred_to_save
                    item["final_response"] = response
                    num_updated += 1
                else:
                    logger.warning(f"Item {i} has unknown type {type(item)}, cannot update with prediction.")

            logger.info(f"Updated {num_updated}/{len(dataset)} dataset items with predictions.")
            if num_missing_answers > 0:
                logger.warning(f"{num_missing_answers} items had no <answer> tag.")
        elif final_responses:
            logger.error(
                f"Mismatch between dataset size ({len(dataset)}) and number of agent responses ({len(final_responses)}). Cannot reliably update dataset."
            )
            for item in dataset:
                if hasattr(item, "update_output"):
                    item.update_output("pred", "RESPONSE_COUNT_MISMATCH")
                elif isinstance(item, dict):
                    item["pred"] = "RESPONSE_COUNT_MISMATCH"

        if do_eval:
            logger.info("📊 Evaluating results using BasicPipeline.evaluate...")
            try:
                dataset = self.evaluate(dataset, do_eval=True, pred_process_fun=pred_process_fun)
                logger.info("✅ Evaluation completed via base class method.")
            except Exception as e:
                logger.error(f"Error during BasicPipeline.evaluate: {e}", exc_info=True)
                logger.warning("Evaluation may be incomplete.")
        else:
            logger.info("Skipping evaluation step as do_eval=False.")

        logger.info("✅ ReSearch pipeline run finished.")
        return dataset


class SGLRemoteGenerator(BaseGenerator):
    """Decoder-only generator backed by a remote SGLang service."""

    def __init__(self, config: Config):
        """Initializes the SGLRemoteGenerator."""
        super().__init__(config)
        logger.info("🔧 Initializing SGLRemoteGenerator...")
        sgl_url = getattr(config, "sgl_remote_url", "localhost:8002")
        self.sgl_remote_url = f"http://{sgl_url}/generate"
        self.health_check_url = f"http://{sgl_url}/health"
        logger.info(f"🔗 Remote Generator URL: {self.sgl_remote_url}")
        self.model_path = getattr(config, "generator_model_path", None)
        if not self.model_path:
            logger.error("generator_model_path not found in config!")
            raise ValueError("generator_model_path is required for SGLRemoteGenerator")

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
            logger.info("✅ Tokenizer loaded for generator.")
        except Exception as e:
            logger.error(f"Failed to load tokenizer for generator from {self.model_path}: {e}", exc_info=True)
            raise

        self.generation_params = getattr(config, "generation_params", {})
        self.config = config

        self._check_health()

    def _check_health(self):
        """Checks the health of the remote generator service."""
        try:
            test_response = requests.get(self.health_check_url, timeout=5)
            test_response.raise_for_status()
            logger.info("✅ Remote generator service is available")
        except requests.exceptions.RequestException as e:
            logger.error(f"Could not connect to or verify the remote generator service at {self.health_check_url}: {str(e)}")
            logger.warning("Please ensure the SGLang service is running and accessible.")

    @retry(max_retries=5, sleep=2)
    def generate(
        self,
        input_list: list[str] | str,
        return_raw_output: bool = False,
        return_scores: bool = False,
        **params,
    ) -> list[str] | tuple[list[str], list[list[float]]] | list[dict]:
        """Generates text using the remote SGLang service."""
        if isinstance(input_list, str):
            input_list = [input_list]
        if not isinstance(input_list, list) or not all(isinstance(item, str) for item in input_list):
            raise ValueError("Input must be a string or a list of strings.")

        batch_size = len(input_list)
        data_to_remote = {"text": input_list}

        effective_params = deepcopy(self.generation_params)
        effective_params.update(params)

        curr_sampling_params = {}
        if effective_params.get("do_sample", True) is False:
            curr_sampling_params["temperature"] = 0.0
        else:
            curr_sampling_params["temperature"] = effective_params.get(
                "temperature", getattr(self.config, "temperature", 0.7)
            )

        default_max_tokens = getattr(self.config, "generator_max_output_len", 1024)
        curr_sampling_params["max_new_tokens"] = effective_params.get("max_new_tokens", default_max_tokens)

        stop_sequences = effective_params.get("stop", [])
        if isinstance(stop_sequences, str):
            stop_sequences = [stop_sequences]
        if stop_sequences:
            curr_sampling_params["stop"] = stop_sequences

        keys_to_remove = ["do_sample", "temperature", "max_new_tokens", "stop"]
        for key in keys_to_remove:
            effective_params.pop(key, None)

        if "top_p" in effective_params:
            curr_sampling_params["top_p"] = effective_params["top_p"]
        if "top_k" in effective_params:
            curr_sampling_params["top_k"] = effective_params["top_k"]

        data_to_remote["sampling_params"] = curr_sampling_params

        if return_scores:
            data_to_remote["return_logprob"] = True
            data_to_remote["top_logprobs_num"] = getattr(self.config, "top_logprobs_num", 2)

        try:
            response = requests.post(
                self.sgl_remote_url, json=data_to_remote, timeout=120, headers={"Content-Type": "application/json"}
            )
            response.raise_for_status()

            response_list = response.json()

            if return_raw_output:
                return response_list

            generated_text = []
            for item in response_list:
                text = item.get("text", "")
                finish_reason = item.get("meta_info", {}).get("finish_reason", {})
                matched_stop = finish_reason.get("matched")
                if matched_stop and curr_sampling_params.get("stop") and matched_stop in curr_sampling_params["stop"]:
                    text += matched_stop
                generated_text.append(text)

            if return_scores:
                scores = []
                for resp_item in response_list:
                    logprobs_list = resp_item.get("meta_info", {}).get("output_token_logprobs", [])
                    token_scores = [
                        np.exp(logprob[0]) if (logprob and len(logprob) > 0) else 0.0 for logprob in logprobs_list
                    ]
                    scores.append(token_scores)
                return generated_text, scores
            else:
                return generated_text

        except requests.exceptions.Timeout:
            logger.error("Generation request timed out after 120 seconds.")
            raise
        except requests.exceptions.ConnectionError:
            logger.error(f"Could not connect to remote generator service at {self.sgl_remote_url}.")
            raise
        except requests.exceptions.RequestException as e:
            logger.error(f"Network error during generation: {str(e)}", exc_info=True)
            raise
        except json.JSONDecodeError:
            response_text = "Unknown (error occurred before response object assignment)"
            if "response" in locals() and hasattr(response, "text"):
                response_text = response.text[:500]
            logger.error(
                f"Failed to decode JSON response from {self.sgl_remote_url}. Response text: {response_text}...",
                exc_info=True,
            )
            raise
        except Exception as e:
            logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True)
            raise


def load_dataset_items(config: Config, split: str) -> list[dict | object]:
    """Loads dataset items using FlashRAG's get_dataset."""
    logger.info(f"📚 Loading dataset: {config.dataset_name}, Split: {split}")
    try:
        all_splits = get_dataset(config)
        if split not in all_splits:
            logger.error(
                f"Split '{split}' not found in dataset '{config.dataset_name}'. Available splits: {list(all_splits.keys())}"
            )
            return []
        dataset_items = all_splits[split]
        logger.info(f"Successfully loaded {len(dataset_items)} items for split '{split}'.")

        return dataset_items
    except FileNotFoundError:
        logger.error(
            f"Dataset files not found for '{config.dataset_name}' in '{config.data_dir}'. Check config and paths."
        )
        return []
    except Exception as e:
        logger.error(f"Error loading dataset using get_dataset: {e}", exc_info=True)
        return []


def save_results(args: argparse.Namespace, config: Config, result_dataset, run_duration: float):
    """Saves summary and debug information."""
    logger.info("💾 Saving results...")
    summary_file = os.path.join(args.save_dir, f"{args.save_note}_summary.txt")
    debug_file = os.path.join(args.save_dir, f"{args.save_note}_debug.json")

    num_items = len(result_dataset)

    logger.info(f"Saving summary results to {summary_file}...")
    try:
        with open(summary_file, "w", encoding="utf-8") as f:
            f.write("EVALUATION SUMMARY\n")
            f.write("=================\n\n")
            f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Run Duration: {run_duration:.2f} seconds\n")
            f.write(f"Dataset: {config.dataset_name} ({args.split} split)\n")
            f.write(f"Model: {config.generator_model_path}\n")
            f.write(f"Retriever: {config.remote_retriever_url}\n")
            f.write(f"Agent Max Generations: {getattr(config, 'agent_max_generations', 'N/A')}\n")
            f.write(f"Generator Max Output Len: {getattr(config, 'generator_max_output_len', 'N/A')}\n\n")
            f.write(f"Total items processed: {num_items}\n")
            f.write("\nNote: Verification was skipped in this run.\n")
            f.write("Note: Overall metrics (like EM, F1) are usually printed to console by the evaluate method.\n")

        logger.info(f"✅ Summary saved to {summary_file}")
    except Exception as e:
        logger.error(f"Error saving summary file '{summary_file}': {e}", exc_info=True)

    logger.info(f"Saving debug information (predictions & responses) to {debug_file}...")
    try:
        debug_data = []
        for i, item in enumerate(result_dataset):
            item_data: dict[str, object] = {}

            def get_item_value(data_item, key_or_attr: str) -> str | int | float | list | bool | None:
                if isinstance(data_item, dict):
                    return data_item.get(key_or_attr)
                elif hasattr(data_item, key_or_attr):
                    return getattr(data_item, key_or_attr)
                return None

            item_data["item_index"] = i
            item_data["question"] = get_item_value(item, "question")
            item_data["prediction"] = get_item_value(item, "pred")
            item_data["final_response"] = get_item_value(item, "final_response")

            gt_answer_val = None
            try:
                gt_answer_val = get_item_value(item, "answer")
                if gt_answer_val is None:
                    answers_list = get_item_value(item, "answers")
                    if isinstance(answers_list, list) and answers_list:
                        raw_ans = answers_list[0]
                        if isinstance(raw_ans, (str, int, float, bool)):
                            gt_answer_val = raw_ans
                        else:
                            gt_answer_val = str(raw_ans)
                elif not isinstance(gt_answer_val, (str, int, float, bool)):
                    gt_answer_val = str(gt_answer_val)
            except Exception as e:
                logger.warning(f"Could not safely get ground truth for item {i}: {e}")
                gt_answer_val = "ERROR_GETTING_ANSWER"
            item_data["ground_truth"] = gt_answer_val

            eval_score_val = None
            try:
                eval_score_val = get_item_value(item, "score")
                if not isinstance(eval_score_val, (str, int, float, bool, type(None))):
                    eval_score_val = str(eval_score_val)
            except Exception as e:
                logger.warning(f"Could not safely get score for item {i}: {e}")
                eval_score_val = "ERROR_GETTING_SCORE"
            item_data["eval_score"] = eval_score_val

            debug_data.append(item_data)

        with open(debug_file, "w", encoding="utf-8") as f:
            json.dump(debug_data, f, indent=2, ensure_ascii=False)
        logger.info(f"✅ Debug information saved to {debug_file}")
    except Exception as e:
        logger.error(f"Error saving debug file '{debug_file}': {e}", exc_info=True)


def research(args: argparse.Namespace, config: Config):
    """Main function to run the ReSearch evaluation pipeline."""
    logger.info("🚀 Starting research pipeline execution...")
    start_time = time.time()

    test_data = load_dataset_items(config, args.split)
    if not test_data:
        logger.error("Failed to load test data. Exiting.")
        return

    try:
        logger.info("🏗️ Building ReSearchPipeline...")
        pipeline = ReSearchPipeline(config)
        logger.info("✅ Pipeline built successfully.")
    except Exception as e:
        logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
        return

    agent_max_generations = getattr(config, "agent_max_generations", 32)
    generator_max_output_len = getattr(config, "generator_max_output_len", 24576)

    try:
        logger.info("🏃 Starting pipeline run...")
        result_dataset = pipeline.run(test_data, do_eval=True)
        logger.info("✅ Pipeline run completed.")
    except Exception as e:
        logger.error(f"Error during pipeline run: {e}", exc_info=True)
        result_dataset = test_data
        logger.warning("Pipeline run failed, attempting to save inputs/partial results.")

    run_duration = time.time() - start_time
    logger.info(f"Total run duration: {run_duration:.2f} seconds.")
    save_results(args, config, result_dataset, run_duration)

    logger.info("🏁 Research pipeline execution finished.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Running ReSearch Evaluation Pipeline")
    parser.add_argument(
        "--config_path", type=str, default="./eval_config.yaml", help="Path to the main FlashRAG config file."
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="bamboogle",
        help="Name of the dataset (must match config or data_dir structure).",
    )
    parser.add_argument(
        "--split", type=str, default="test", help="Dataset split to evaluate (e.g., test, validation)."
    )
    parser.add_argument("--save_dir", type=str, default="./output_logs", help="Directory to save logs and results.")
    parser.add_argument("--save_note", type=str, default="research_run", help="A note to prepend to saved filenames.")

    parser.add_argument("--data_dir", type=str, help="Override data directory specified in config.")
    parser.add_argument(
        "--sgl_remote_url", type=str, help="Override SGLang remote generator URL (e.g., localhost:8002)."
    )
    parser.add_argument(
        "--remote_retriever_url", type=str, help="Override remote retriever URL (e.g., localhost:8001)."
    )
    parser.add_argument("--generator_model_path", type=str, help="Override generator model path specified in config.")
    parser.add_argument("--retriever_topk", type=int, help="Override retriever top K.")
    parser.add_argument("--generator_max_output_len", type=int, help="Override generator max output length.")
    parser.add_argument("--agent_max_generations", type=int, help="Override agent max interaction turns.")

    args = parser.parse_args()
    logger.info(f"Starting evaluation script with arguments: {args}")

    try:
        os.makedirs(args.save_dir, exist_ok=True)
        logger.info(f"💾 Logs and results will be saved to: {args.save_dir}")
    except OSError as e:
        logger.error(f"Could not create save directory '{args.save_dir}': {e}", exc_info=True)
        exit(1)

    config_overrides = {
        k: v
        for k, v in vars(args).items()
        if v is not None
        and k
        not in [
            "config_path",
            "dataset_name",
            "split",
            "save_dir",
            "save_note",
        ]
    }

    logger.info(f"🔧 Loading configuration from: {args.config_path}")
    try:
        config = Config(args.config_path, config_dict=config_overrides)
        config.dataset_name = args.dataset_name
        if args.data_dir:
            config.data_dir = args.data_dir

        logger.info(f"Effective data_dir: {getattr(config, 'data_dir', 'N/A')}")
        logger.info(f"Effective generator_model_path: {getattr(config, 'generator_model_path', 'N/A')}")
        logger.info(f"Effective sgl_remote_url: {getattr(config, 'sgl_remote_url', 'N/A')}")
        logger.info(f"Effective remote_retriever_url: {getattr(config, 'remote_retriever_url', 'N/A')}")

        logger.info("✅ Config loaded and potentially overridden by CLI arguments.")

        config["dataset_path"] = os.path.join(config.data_dir, config.dataset_name)

    except FileNotFoundError:
        logger.error(f"Config file not found at '{args.config_path}'. Please check the path.")
        exit(1)
    except Exception as e:
        logger.error(f"Error loading or processing configuration: {e}", exc_info=True)
        exit(1)

    research(args, config)
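Note on the retry helper removed above: both RemoteRetriever and SGLRemoteGenerator wrap their HTTP calls in @retry so transient service failures back off exponentially (sleep, 2*sleep, 4*sleep, ...) instead of aborting the whole evaluation run. A minimal usage sketch, assuming the retry decorator and logger defined in the script above; the flaky_search function and its failure rate are hypothetical, purely for illustration:

import random

@retry(max_retries=3, sleep=1)
def flaky_search(query: str) -> list[dict]:
    # Fail roughly half the time to exercise the backoff path.
    if random.random() < 0.5:
        raise ConnectionError("simulated transient failure")
    return [{"id": "0", "contents": f"stub result for {query!r}"}]

print(flaky_search("when was the musique dataset released?"))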
@@ -1,68 +0,0 @@
import argparse
import os
import zipfile

from dotenv import load_dotenv
from huggingface_hub import snapshot_download

from config import DATA_DIR


def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments
    """
    parser = argparse.ArgumentParser(description="Download FlashRAG datasets from HuggingFace Hub")
    parser.add_argument(
        "--repo-id",
        type=str,
        default="RUC-NLPIR/FlashRAG_datasets",
        help="HuggingFace repository ID",
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=DATA_DIR / "flashrag_datasets",
        help="Local directory to save the datasets",
    )

    return parser.parse_args()


def main():
    """Main function to download the datasets."""
    args = parse_args()
    load_dotenv(override=True)

    # Configuration
    HF_TOKEN = os.getenv("HF_TOKEN")

    ALLOW_PATTERNS = [
        "*retrieval-corpus*",
        "*bamboogle*",
        "*nq*",
    ]

    # Download the dataset files
    snapshot_download(
        token=HF_TOKEN,
        repo_id=args.repo_id,
        local_dir=args.local_dir,
        repo_type="dataset",
        # ignore_patterns=IGNORE_PATTERNS,
        allow_patterns=ALLOW_PATTERNS,
    )

    # Unzip data/flashrag_datasets/retrieval-corpus/wiki18_100w.zip
    print("Unzipping wiki18_100w.zip. This might take a while...")
    zip_file_path = os.path.join(args.local_dir, "retrieval-corpus", "wiki18_100w.zip")
    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
        zip_ref.extractall(args.local_dir)

    print(f"✅ Done: {args.repo_id} -> {args.local_dir}")


if __name__ == "__main__":
    main()
@@ -1,43 +0,0 @@
"""Download and extract FlashRAG index."""

import os
import zipfile

import requests
from tqdm import tqdm

from config import DATA_DIR

# Constants
URL = "https://www.modelscope.cn/datasets/hhjinjiajie/FlashRAG_Dataset/resolve/master/retrieval_corpus/wiki18_100w_e5_index.zip"
ZIP_NAME = "wiki18_100w_e5_index.zip"
zip_path = DATA_DIR / ZIP_NAME

# Download with progress bar
print("📥 Downloading index...")
response = requests.get(URL, stream=True)
total_size = int(response.headers.get("content-length", 0))

with (
    open(zip_path, "wb") as f,
    tqdm(
        desc=ZIP_NAME,
        total=total_size,
        unit="iB",
        unit_scale=True,
        unit_divisor=1024,
    ) as bar,
):
    for data in response.iter_content(chunk_size=1024):
        size = f.write(data)
        bar.update(size)

# Extract
print("📦 Extracting index...")
with zipfile.ZipFile(zip_path, "r") as zip_ref:
    zip_ref.extractall(DATA_DIR)

# Clean up zip
os.remove(zip_path)
print("✅ Download and extraction completed successfully!")
print(f"Index file is at: {DATA_DIR}/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index")
@@ -1,54 +0,0 @@
import argparse
import os

from dotenv import load_dotenv
from huggingface_hub import snapshot_download

from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID


def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments
    """
    parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
    parser.add_argument(
        "--repo-id",
        type=str,
        default=GENERATOR_MODEL_REPO_ID,
        help="HuggingFace repository ID",
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=GENERATOR_MODEL_DIR,
        help="Local directory to save model",
    )

    return parser.parse_args()


def main():
    """Main function to download model."""
    args = parse_args()
    load_dotenv(override=True)

    # Configuration
    HF_TOKEN = os.getenv("HF_TOKEN")

    print("Downloading model to", args.local_dir)

    # Download the model
    snapshot_download(
        token=HF_TOKEN,
        repo_id=args.repo_id,
        local_dir=args.local_dir,
        repo_type="model",
    )
    print(f"✅ Done: {args.repo_id} -> {args.local_dir}")


if __name__ == "__main__":
    main()
@@ -1,54 +0,0 @@
import argparse
import os

from dotenv import load_dotenv
from huggingface_hub import snapshot_download

from config import RETRIEVER_MODEL_DIR, RETRIEVER_MODEL_REPO_ID


def parse_args() -> argparse.Namespace:
    """Parse command line arguments.

    Returns:
        argparse.Namespace: Parsed arguments
    """
    parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
    parser.add_argument(
        "--repo-id",
        type=str,
        default=RETRIEVER_MODEL_REPO_ID,
        help="HuggingFace repository ID",
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=RETRIEVER_MODEL_DIR,
        help="Local directory to save model",
    )

    return parser.parse_args()


def main():
    """Main function to download model."""
    args = parse_args()
    load_dotenv(override=True)

    # Configuration
    HF_TOKEN = os.getenv("HF_TOKEN")

    print("Downloading model to", args.local_dir)

    # Download the model
    snapshot_download(
        token=HF_TOKEN,
        repo_id=args.repo_id,
        local_dir=args.local_dir,
        repo_type="model",
    )
    print(f"✅ Done: {args.repo_id} -> {args.local_dir}")


if __name__ == "__main__":
    main()
@@ -1,9 +0,0 @@
# ------------------------------------------------Environment Settings------------------------------------------------#
gpu_id: "0"

# -------------------------------------------------Retrieval Settings------------------------------------------------#
# If a name is set, the model path will be looked up in the global paths.
retrieval_method: "e5" # name or path of the retrieval model
index_path: "/mnt/nas/thinhlpg/data/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index" # path to the index file
faiss_gpu: False # whether to use GPU to hold the index
corpus_path: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/wiki18_100w.jsonl" # path to the corpus in '.jsonl' format that stores the documents
@@ -1,127 +0,0 @@
import subprocess
import sys
from pathlib import Path

# Add project root to sys.path to allow importing config
# Assuming the script is at DeepSearch/scripts/serving/serve_generator.py
# The project root (DeepSearch) is parents[2]
PROJ_ROOT = Path(__file__).resolve().parents[2]
if str(PROJ_ROOT) not in sys.path:
    sys.path.append(str(PROJ_ROOT))

# Import after adjusting sys.path
try:
    from config import (
        GENERATOR_MODEL_REPO_ID,
        GENERATOR_SERVER_PORT,
        MODEL_CONFIG,
        logger,
    )
except ImportError as e:
    # Use print here as logger might not be available if import failed
    print(
        f"Error importing config: {e}. Make sure config.py is in the project root ({PROJ_ROOT}) and added to sys.path."
    )
    sys.exit(1)


def launch_sglang_server(
    model_id: str,
    port: int,
    context_length: int,
    host: str = "0.0.0.0",
    dtype: str = "bfloat16",
) -> None:
    """Launches the SGLang server using specified configurations.

    Args:
        model_id: The Hugging Face repository ID of the model.
        port: The port number for the server.
        context_length: The maximum context length for the model.
        host: The host address for the server.
        dtype: The data type for the model (e.g., 'bfloat16', 'float16').
    """
    command = [
        sys.executable,  # Use the current Python interpreter
        "-m",
        "sglang.launch_server",
        "--model-path",
        model_id,
        "--context-length",
        str(context_length),
        "--enable-metrics",
        "--dtype",
        dtype,
        "--host",
        host,
        "--port",
        str(port),
        "--mem-fraction-static",
        "0.5",
        "--trust-remote-code",
        # Recommended by SGLang for stability sometimes
        "--disable-overlap",
        # Can sometimes cause issues
        "--disable-radix-cache",
    ]

    # Log the command clearly
    command_str = " ".join(command)
    logger.info(f"🚀 Launching SGLang server with command: {command_str}")

    process = None  # Initialize process to None
    try:
        # Use Popen to start the server process
        # It runs in the foreground relative to this script,
        # but allows us to catch KeyboardInterrupt cleanly.
        process = subprocess.Popen(command)
        # Wait for the process to complete (e.g., user interruption)
        process.wait()
        # Check return code after waiting
        if process.returncode != 0:
            logger.error(f"💥 SGLang server process exited with error code: {process.returncode}")
            sys.exit(process.returncode)
        else:
            logger.info("✅ SGLang server process finished gracefully.")

    except FileNotFoundError:
        logger.error("💥 Error: Python executable or sglang module not found.")
        logger.error(f"Ensure '{sys.executable}' is correct and sglang is installed.")
        sys.exit(1)
    except KeyboardInterrupt:
        logger.info("🛑 SGLang server launch interrupted by user. Stopping server...")
        # Attempt to terminate the process gracefully
        if process and process.poll() is None:  # Check if process exists and is running
            process.terminate()
            try:
                process.wait(timeout=5)  # Wait a bit for termination
                logger.info("✅ Server terminated gracefully.")
            except subprocess.TimeoutExpired:
                logger.warning("⚠️ Server did not terminate gracefully, forcing kill.")
                process.kill()
        sys.exit(0)  # Exit cleanly after interrupt
    except Exception as e:
        # Catch any other unexpected exceptions during launch or waiting
        logger.error(f"🚨 An unexpected error occurred: {e}")
        # Ensure process is cleaned up if it exists
        if process and process.poll() is None:
            process.kill()
        sys.exit(1)


if __name__ == "__main__":
    # Get context length from config, default to 8192
    context_len = MODEL_CONFIG.get("max_seq_length", 8192)

    logger.info("----------------------------------------------------")
    logger.info("✨ Starting SGLang Generator Server ✨")
    logger.info(f"   Model ID: {GENERATOR_MODEL_REPO_ID}")
    logger.info(f"   Port: {GENERATOR_SERVER_PORT}")
    logger.info(f"   Context Length: {context_len}")
    logger.info("----------------------------------------------------")

    launch_sglang_server(
        model_id=GENERATOR_MODEL_REPO_ID,
        port=GENERATOR_SERVER_PORT,
        context_length=context_len,
    )
@@ -1,113 +0,0 @@
import argparse
import asyncio
from collections import deque
from typing import List, Tuple, Union

from fastapi import FastAPI, HTTPException
from flashrag.config import Config
from flashrag.utils import get_retriever
from pydantic import BaseModel

from config import RETRIEVER_SERVER_PORT

app = FastAPI()

retriever_list = []
available_retrievers = deque()
retriever_semaphore = None


def init_retriever(args):
    global retriever_semaphore
    config = Config(args.config)
    for i in range(args.num_retriever):
        print(f"Initializing retriever {i + 1}/{args.num_retriever}")
        retriever = get_retriever(config)
        retriever_list.append(retriever)
        available_retrievers.append(i)
    # Create a semaphore to limit the number of retrievers that can be used concurrently
    retriever_semaphore = asyncio.Semaphore(args.num_retriever)


@app.get("/health")
async def health_check():
    return {"status": "healthy", "retrievers": {"total": len(retriever_list), "available": len(available_retrievers)}}


class QueryRequest(BaseModel):
    query: str
    top_n: int = 10
    return_score: bool = False


class BatchQueryRequest(BaseModel):
    query: List[str]
    top_n: int = 10
    return_score: bool = False


class Document(BaseModel):
    id: str
    contents: str


@app.post("/search", response_model=Union[Tuple[List[Document], List[float]], List[Document]])
async def search(request: QueryRequest):
    query = request.query
    top_n = request.top_n
    return_score = request.return_score

    if not query or not query.strip():
        print(f"Query content cannot be empty: {query}")
        raise HTTPException(status_code=400, detail="Query content cannot be empty")

    async with retriever_semaphore:
        retriever_idx = available_retrievers.popleft()
        try:
            if return_score:
                results, scores = retriever_list[retriever_idx].search(query, top_n, return_score)
                return [Document(id=result["id"], contents=result["contents"]) for result in results], scores
            else:
                results = retriever_list[retriever_idx].search(query, top_n, return_score)
                return [Document(id=result["id"], contents=result["contents"]) for result in results]
        finally:
            available_retrievers.append(retriever_idx)


@app.post("/batch_search", response_model=Union[List[List[Document]], Tuple[List[List[Document]], List[List[float]]]])
async def batch_search(request: BatchQueryRequest):
    query = request.query
    top_n = request.top_n
    return_score = request.return_score

    async with retriever_semaphore:
        retriever_idx = available_retrievers.popleft()
        try:
            if return_score:
                results, scores = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
                return [
                    [Document(id=result["id"], contents=result["contents"]) for result in results[i]]
                    for i in range(len(results))
                ], scores
            else:
                results = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
                return [
                    [Document(id=result["id"], contents=result["contents"]) for result in results[i]]
                    for i in range(len(results))
                ]
        finally:
            available_retrievers.append(retriever_idx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./retriever_config.yaml")
    parser.add_argument("--num_retriever", type=int, default=1)
    parser.add_argument("--port", type=int, default=RETRIEVER_SERVER_PORT)
    args = parser.parse_args()

    init_retriever(args)

    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=args.port)
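For reference, the /search endpoint above expects a JSON body matching QueryRequest (query, top_n, return_score) and returns a list of Document objects. A client sketch, assuming the server runs on localhost:8001 (the default remote_retriever_url used by the evaluation pipeline in this commit); the query string is only an example:

import requests

resp = requests.post(
    "http://localhost:8001/search",
    json={"query": "capital of France", "top_n": 3, "return_score": False},
    timeout=30,
)
resp.raise_for_status()
for doc in resp.json():
    print(doc["id"], doc["contents"][:80])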
@@ -1,135 +0,0 @@
import json
import math  # Import math for ceiling division
import sys
import traceback  # Import traceback
from pathlib import Path

import pandas as pd

# Add project root to Python path if needed (adjust relative path as necessary)
project_root = Path(__file__).resolve().parent.parent
sys.path.append(str(project_root))

from src.embeddings import CustomHuggingFaceEmbeddings

# Import FAISS after potentially adding to sys.path
try:
    from langchain_community.vectorstores import FAISS
except ImportError:
    print("Error: langchain_community or FAISS not installed. Please install with 'pip install langchain faiss-cpu'")
    sys.exit(1)


def build_faiss_index_from_csv(csv_path: str, index_save_path: str, batch_size: int = 128) -> None:
    """Builds a FAISS index from a CSV containing paragraph content and metadata.

    Reads a CSV file, generates embeddings for the 'content' column in batches,
    and saves the FAISS index files (index.faiss, index.pkl) locally.

    Args:
        csv_path: Path to the input CSV file (e.g., data/processed/paragraphs.csv).
        index_save_path: Path to the directory where the index files should be saved.
        batch_size: Number of texts to process in each embedding batch.
    """
    print(f"Loading paragraphs from {csv_path}")
    try:
        df = pd.read_csv(csv_path)
    except FileNotFoundError:
        print(f"Error: CSV file not found at {csv_path}. Please run the extraction script first.")
        return
    except Exception as e:
        print(f"Error reading CSV file: {e}")
        return

    if "content" not in df.columns or "metadata" not in df.columns:
        print("Error: CSV file must contain 'content' and 'metadata' columns.")
        return

    if df.empty:
        print("Warning: Input CSV file is empty. No index will be built.")
        return

    # Prepare documents for FAISS
    texts = df["content"].astype(str).tolist()
    metadatas = []
    try:
        metadatas = [json.loads(m) for m in df["metadata"].tolist()]
        print(f"Prepared {len(texts)} texts and {len(metadatas)} metadatas.")
    except json.JSONDecodeError as e:
        print(f"Error parsing metadata JSON: {e}. Check the format in {csv_path}")
        traceback.print_exc()  # Print traceback for JSON errors
        return
    except Exception as e:
        print(f"Error processing metadata: {e}")
        traceback.print_exc()  # Print traceback for other metadata errors
        return

    if not texts or not metadatas or len(texts) != len(metadatas):
        print(f"Error: Mismatch or empty texts/metadatas. Texts: {len(texts)}, Metadatas: {len(metadatas)}")
        return

    print("Initializing embeddings model...")
    try:
        embeddings = CustomHuggingFaceEmbeddings()
    except Exception as e:
        print(f"Error initializing embeddings model: {e}")
        traceback.print_exc()
        return
    print("Embeddings model initialized successfully.")

    vectorstore = None
    num_batches = math.ceil(len(texts) / batch_size)
    print(f"Processing {len(texts)} texts in {num_batches} batches of size {batch_size}...")

    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = min((i + 1) * batch_size, len(texts))
        batch_texts = texts[start_idx:end_idx]
        batch_metadatas = metadatas[start_idx:end_idx]
        print(f"  Processing batch {i + 1}/{num_batches} (indices {start_idx}-{end_idx - 1})...")

        try:
            if i == 0:
                # Initialize the vector store with the first batch
                print("  Initializing FAISS index with first batch...")
                vectorstore = FAISS.from_texts(texts=batch_texts, embedding=embeddings, metadatas=batch_metadatas)
                print("  FAISS index initialized.")
            else:
                # Add subsequent batches to the existing store
                if vectorstore is None:
                    print("Error: vectorstore is None after first batch, cannot add more texts.")
                    return  # Should not happen if first batch succeeded
                print(f"  Adding batch {i + 1} to FAISS index...")
                vectorstore.add_texts(texts=batch_texts, metadatas=batch_metadatas)
                print(f"  Batch {i + 1} added.")

        except Exception as e:
            print(f"Error processing batch {i + 1} (indices {start_idx}-{end_idx - 1}): {e}")
            traceback.print_exc()
            print("Stopping index creation due to error in batch processing.")
            return  # Exit if any batch fails

    if vectorstore is None:
        print("Error: Failed to create or add any data to the vectorstore.")
        return

    # Save the completed index
    try:
        print(f"Attempting to save final FAISS index files to directory: {index_save_path}")
        # Ensure the target directory exists before saving
        Path(index_save_path).mkdir(parents=True, exist_ok=True)
        vectorstore.save_local(index_save_path)
        print(f"Successfully saved final FAISS index files (index.faiss, index.pkl) to: {index_save_path}")
    except Exception as e:
        print(f"Error during final vectorstore.save_local to {index_save_path}: {e}")
        traceback.print_exc()


if __name__ == "__main__":
    # Define paths relative to this script or use absolute paths
    PROCESSED_DIR = Path("data/processed")
    INPUT_CSV = str(PROCESSED_DIR / "paragraphs.csv")
    # FAISS save_local will save index.faiss and index.pkl in this directory
    INDEX_SAVE_DIR = str(PROCESSED_DIR)  # Save directly to processed dir

    build_faiss_index_from_csv(INPUT_CSV, INDEX_SAVE_DIR, batch_size=128)
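Once build_faiss_index_from_csv has written index.faiss and index.pkl, the index can be reloaded and queried through LangChain. A rough sketch, assuming the same CustomHuggingFaceEmbeddings class and the data/processed save directory used above; the allow_dangerous_deserialization flag is required by newer langchain_community releases and may not exist in older ones, and the query string is only an example:

from langchain_community.vectorstores import FAISS

from src.embeddings import CustomHuggingFaceEmbeddings

embeddings = CustomHuggingFaceEmbeddings()
vectorstore = FAISS.load_local("data/processed", embeddings, allow_dangerous_deserialization=True)
for doc in vectorstore.similarity_search("Who founded the company?", k=3):
    print(doc.metadata.get("source_question_ids"), doc.page_content[:80])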
@@ -1,30 +0,0 @@
#!/bin/bash
# This script is taken from https://github.com/StonyBrookNLP/musique with slight modifications

set -e
set -x

# If gdown doesn't work, you can download files from mentioned URLs manually
# and put them at appropriate locations.
pip install gdown

ZIP_NAME="musique_v1.0.zip"

# URL: https://drive.google.com/file/d/1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h/view?usp=sharing
gdown --id 1tGdADlNjWFaHLeZZGShh2IRcpO6Lv24h --output $ZIP_NAME

TARGET_DIR="./data/raw"
mkdir -p $TARGET_DIR
unzip -o $(basename $ZIP_NAME) -d $TARGET_DIR  # Extract directly into target

# Move contents from the extracted 'data' folder up one level
mv $TARGET_DIR/data/* $TARGET_DIR/

# Clean up the empty directory and the zip
rm -rf $TARGET_DIR/data
rm $ZIP_NAME

# TODO: prevent these from zipping in.
rm -rf __MACOSX
# Clean up potential extracted .DS_Store
rm -f $TARGET_DIR/.DS_Store
@ -1,101 +0,0 @@
|
|||||||
import json
import sys
from collections import defaultdict  # Use defaultdict for cleaner accumulation
from pathlib import Path

import pandas as pd

# Add project root to Python path if needed (adjust relative path as necessary)
# project_root = Path(__file__).resolve().parent.parent
# sys.path.append(str(project_root))
# from config import logger  # Assuming you have a logger setup


def extract_unique_paragraphs(input_paths: list[str], output_csv_path: str) -> None:
    """Extracts unique paragraphs from specified JSONL files.

    Reads Musique JSONL files (train, dev, test), finds unique paragraphs
    (regardless of is_supporting flag), combines title and text,
    tracks source question IDs, and saves to CSV.

    Args:
        input_paths: A list of paths to the input JSONL files.
        output_csv_path: Path to save the output CSV file.
    """
    output_dir = Path(output_csv_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)

    # Use paragraph content as key, value is the set of source question IDs
    paragraphs_data = defaultdict(set)
    print("Starting paragraph extraction (including non-supporting)...")

    for file_path in input_paths:
        print(f"Processing file: {file_path}")
        try:
            with open(file_path, "r", encoding="utf-8") as infile:
                for line_num, line in enumerate(infile, 1):
                    try:
                        data = json.loads(line)
                        main_question_id = data.get("id")
                        if not main_question_id:
                            print(f"Warning: Missing 'id' in line {line_num} of {file_path}")
                            continue

                        for p in data.get("paragraphs", []):
                            title = p.get("title", "No Title")
                            text = p.get("paragraph_text", "")
                            content = f"{title}\n{text}".strip()

                            if not content:
                                continue  # Skip empty paragraphs

                            paragraphs_data[content].add(main_question_id)

                    except json.JSONDecodeError:
                        print(f"Warning: Skipping invalid JSON in line {line_num} of {file_path}")
                    except Exception as e:
                        print(f"Warning: Error processing line {line_num} in {file_path}: {e}")
        except FileNotFoundError:
            print(f"Error: Input file not found: {file_path}")
        except Exception as e:
            print(f"Error reading file {file_path}: {e}")

    print(f"Found {len(paragraphs_data)} unique paragraphs (supporting and non-supporting).")

    # Prepare data for DataFrame
    output_list = []
    sorted_content = sorted(paragraphs_data.keys())
    for chunk_id, content in enumerate(sorted_content, 1):
        question_ids = paragraphs_data[content]
        metadata = {"source_question_ids": sorted(list(question_ids))}
        output_list.append(
            {
                "chunk_id": chunk_id,
                "content": content,
                "metadata": json.dumps(metadata),  # Store metadata as JSON string
            }
        )

    if not output_list:
        print("No paragraphs found to save.")
        return
    df = pd.DataFrame(output_list)
    try:
        df.to_csv(output_csv_path, index=False)
        print(f"Successfully saved unique paragraphs to {output_csv_path}")
    except Exception as e:
        print(f"Error saving CSV file: {e}")


if __name__ == "__main__":
    RAW_DIR = Path("data/raw")
    PROCESSED_DIR = Path("data/processed")

    input_files = [
        str(RAW_DIR / "musique_ans_v1.0_train.jsonl"),
        str(RAW_DIR / "musique_ans_v1.0_dev.jsonl"),
        str(RAW_DIR / "musique_ans_v1.0_test.jsonl"),
    ]
    output_csv = str(PROCESSED_DIR / "paragraphs.csv")

    extract_unique_paragraphs(input_files, output_csv)
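As a quick check of the CSV produced above, the sketch below (not part of the original script) reads paragraphs.csv back with pandas and parses the metadata column, which was stored as a JSON string:

import json

import pandas as pd

df = pd.read_csv("data/processed/paragraphs.csv")
df["metadata"] = df["metadata"].apply(json.loads)  # restore the dict stored as a JSON string
print(f"{len(df)} unique paragraphs")
print(df.loc[0, "content"][:80])
print(df.loc[0, "metadata"]["source_question_ids"][:5])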
@ -1,155 +0,0 @@
"""Prepares a deterministic sampled dev set (questions_dev.jsonl) from raw Musique dev data."""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def transform_musique_dev_data(input_path: str, output_path: str, sample_config: dict) -> None:
|
|
||||||
"""Transforms Musique dev data with deterministic stratified sampling using uniform selection from sorted lists.
|
|
||||||
|
|
||||||
Reads dev data, categorizes by hop type (2, 3, 4), sorts categories by ID,
|
|
||||||
selects N samples uniformly spaced from each sorted category based on sample_config,
|
|
||||||
combines, sorts final list by ID, combines answers/aliases, extracts supporting paras,
|
|
||||||
and writes the transformed data to output_path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_path: Path to the input JSONL file (e.g., data/raw/musique_ans_v1.0_dev.jsonl).
|
|
||||||
output_path: Path to the output JSONL file (e.g., data/processed/questions_dev.jsonl).
|
|
||||||
sample_config: Dictionary specifying samples per hop type (e.g., {"2hop": 20, "3hop": 15, "4hop": 15}).
|
|
||||||
"""
|
|
||||||
output_dir = Path(output_path).parent
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
print(f"Reading all data from {input_path} for dev sampling...")
|
|
||||||
all_data = []
|
|
||||||
try:
|
|
||||||
with open(input_path, "r", encoding="utf-8") as infile:
|
|
||||||
for line_num, line in enumerate(infile, 1):
|
|
||||||
try:
|
|
||||||
data = json.loads(line)
|
|
||||||
if "id" in data:
|
|
||||||
all_data.append(data)
|
|
||||||
else:
|
|
||||||
print(f"Warning: Skipping line {line_num} due to missing 'id' field in {input_path}")
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
print(f"Warning: Skipping invalid JSON in line {line_num} of {input_path}")
|
|
||||||
except FileNotFoundError:
|
|
||||||
print(f"Error: Input file not found at {input_path}")
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading file {input_path}: {e}")
|
|
||||||
return
|
|
||||||
print(f"Read {len(all_data)} total samples from dev set.")
|
|
||||||
|
|
||||||
# Categorize data by hop count (2hop, 3hop, 4hop)
|
|
||||||
categorized_data = defaultdict(list)
|
|
||||||
print("Categorizing data by hop type (2, 3, 4)...")
|
|
||||||
for data in all_data:
|
|
||||||
q_id = data["id"]
|
|
||||||
hop_type = None
|
|
||||||
if q_id.startswith("2hop"):
|
|
||||||
hop_type = "2hop"
|
|
||||||
elif q_id.startswith("3hop"):
|
|
||||||
hop_type = "3hop"
|
|
||||||
elif q_id.startswith("4hop"):
|
|
||||||
hop_type = "4hop"
|
|
||||||
|
|
||||||
if hop_type:
|
|
||||||
categorized_data[hop_type].append(data)
|
|
||||||
|
|
||||||
# Deterministic sampling using sorting and uniform index selection
|
|
||||||
final_sample_list = []
|
|
||||||
total_target = sum(sample_config.values())
|
|
||||||
print(f"Sampling deterministically via uniform selection from sorted lists to get {total_target} dev samples...")
|
|
||||||
|
|
||||||
for hop_type, target_count in sample_config.items():
|
|
||||||
available_samples = categorized_data.get(hop_type, [])
|
|
||||||
current_count = len(available_samples)
|
|
||||||
print(f" {hop_type}: Found {current_count} samples, need {target_count}.")
|
|
||||||
|
|
||||||
if current_count == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
available_samples.sort(key=lambda x: x["id"])
|
|
||||||
selected_samples_for_hop = []
|
|
||||||
if current_count < target_count:
|
|
||||||
print(f" Warning: Not enough samples for {hop_type}. Taking all {current_count} sorted samples.")
|
|
||||||
selected_samples_for_hop = available_samples
|
|
||||||
elif target_count > 0: # Ensure target_count is positive before selecting
|
|
||||||
print(f" Selecting {target_count} samples uniformly from {current_count}...")
|
|
||||||
# Calculate indices using integer interpretation of evenly spaced points
|
|
||||||
indices_to_take = [
|
|
||||||
int(i * (current_count - 1) / (target_count - 1)) if target_count > 1 else 0
|
|
||||||
for i in range(target_count)
|
|
||||||
] # Adjust index calc for edges
|
|
||||||
indices_to_take = sorted(list(set(indices_to_take))) # Ensure unique indices
|
|
||||||
# Simple fallback if uniqueness reduced count below target
|
|
||||||
while len(indices_to_take) < target_count and len(indices_to_take) < current_count:
|
|
||||||
next_val = indices_to_take[-1] + 1
|
|
||||||
if next_val < current_count:
|
|
||||||
indices_to_take.append(next_val)
|
|
||||||
else: # Cannot add more unique indices
|
|
||||||
break
|
|
||||||
selected_samples_for_hop = [
|
|
||||||
available_samples[idx] for idx in indices_to_take[:target_count]
|
|
||||||
] # Select based on unique indices, capped at target
|
|
||||||
|
|
||||||
final_sample_list.extend(selected_samples_for_hop)
|
|
||||||
|
|
||||||
print(f"Selected {len(final_sample_list)} dev samples in total.")
|
|
||||||
|
|
||||||
# Sort the final combined list by ID for consistent output order
|
|
||||||
print("Sorting the final combined dev sample list by ID...")
|
|
||||||
final_sample_list.sort(key=lambda x: x["id"])
|
|
||||||
|
|
||||||
# Process and write the selected samples
|
|
||||||
print(f"Processing and writing {len(final_sample_list)} selected dev samples to {output_path}...")
|
|
||||||
count = 0
|
|
||||||
try:
|
|
||||||
with open(output_path, "w", encoding="utf-8") as outfile:
|
|
||||||
for data in final_sample_list:
|
|
||||||
try:
|
|
||||||
supporting_paragraphs = [
|
|
||||||
p["paragraph_text"] for p in data.get("paragraphs", []) if p.get("is_supporting", False)
|
|
||||||
]
|
|
||||||
main_answer = data.get("answer", "")
|
|
||||||
aliases = data.get("answer_aliases", [])
|
|
||||||
all_answers = [main_answer] + (aliases if isinstance(aliases, list) else [])
|
|
||||||
valid_answers = [str(ans).strip() for ans in all_answers if ans and str(ans).strip()]
|
|
||||||
unique_valid_answers = list(set(valid_answers)) # Keep unique, don't sort alphabetically
|
|
||||||
combined_answer_str = " OR ".join(unique_valid_answers)
|
|
||||||
|
|
||||||
output_data = {
|
|
||||||
"id": data.get("id"),
|
|
||||||
"question": data.get("question"),
|
|
||||||
"answer": combined_answer_str,
|
|
||||||
"supporting_paragraphs": supporting_paragraphs,
|
|
||||||
}
|
|
||||||
outfile.write(json.dumps(output_data) + "\n")
|
|
||||||
count += 1
|
|
||||||
except KeyError as e:
|
|
||||||
print(f"Skipping sample due to missing key {e}: {data.get('id')}")
|
|
||||||
print(f"Successfully processed and wrote {count} dev samples.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An unexpected error occurred during writing: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Define file paths relative to the project root
|
|
||||||
# Ensure this script is run from the project root or adjust paths accordingly
|
|
||||||
RAW_DIR = Path("data/raw")
|
|
||||||
PROCESSED_DIR = Path("data/processed")
|
|
||||||
|
|
||||||
# Define sampling configuration for the dev set
|
|
||||||
DEV_SAMPLING_CONFIG = {"2hop": 20, "3hop": 15, "4hop": 15} # Total = 50
|
|
||||||
|
|
||||||
INPUT_FILE = RAW_DIR / "musique_ans_v1.0_dev.jsonl"
|
|
||||||
OUTPUT_FILE = PROCESSED_DIR / "questions_dev.jsonl"
|
|
||||||
|
|
||||||
transform_musique_dev_data(str(INPUT_FILE), str(OUTPUT_FILE), DEV_SAMPLING_CONFIG)
|
|
||||||
|
|
||||||
print(f"\nMusique DEV JSONL transformation and deterministic sampling complete.")
@ -1,172 +0,0 @@
import json
import math  # Keep math import
import os
import re  # Import re for parsing ID
from collections import defaultdict
from pathlib import Path

# import random # No longer needed
# SEED = 42 # No longer needed
# random.seed(SEED) # No longer needed


def transform_musique_data(input_path: str, output_path: str, sample_config: dict) -> None:
    """Transforms Musique data with deterministic stratified sampling using uniform selection from sorted lists.

    Reads data, categorizes by detailed hop type, sorts categories by ID,
    selects N samples uniformly spaced from each sorted category,
    combines, sorts final list by ID, and writes to output.

    Args:
        input_path: Path to the input JSONL file.
        output_path: Path to the output JSONL file.
        sample_config: Dictionary specifying samples per detailed hop type (e.g., {"2hop": 400, "3hop1": 150, ...}).
    """
    output_dir = Path(output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Reading all data from {input_path} for sampling...")
    all_data = []
    try:
        with open(input_path, "r", encoding="utf-8") as infile:
            for line_num, line in enumerate(infile, 1):
                try:
                    data = json.loads(line)
                    if "id" in data:
                        all_data.append(data)
                    else:
                        print(f"Warning: Skipping line {line_num} due to missing 'id' field in {input_path}")
                except json.JSONDecodeError:
                    print(f"Warning: Skipping invalid JSON in line {line_num} of {input_path}")
    except FileNotFoundError:
        print(f"Error: Input file not found at {input_path}")
        return
    except Exception as e:
        print(f"Error reading file {input_path}: {e}")
        return
    print(f"Read {len(all_data)} total samples with IDs.")

    # Detailed Categorization by hop type
    categorized_data = defaultdict(list)
    print("Categorizing data by detailed hop type (e.g., 3hop1, 4hop2)...")
    for data in all_data:
        q_id = data["id"]
        match = re.match(r"^(2hop|3hop[12]|4hop[123])__", q_id)
        if match:
            detailed_hop_type = match.group(1)
            categorized_data[detailed_hop_type].append(data)
        # else: # Optional: log if an ID doesn't match expected pattern
        #     print(f"Warning: ID {q_id} does not match expected hop pattern.")

    # Deterministic sampling using sorting and uniform index selection
    final_sample_list = []
    total_target = sum(sample_config.values())
    print(f"Sampling deterministically via uniform selection from sorted lists to get {total_target} samples...")
    # Check if all requested hop types exist in config
    for hop_type in sample_config.keys():
        if hop_type not in categorized_data:
            print(f"Warning: Hop type '{hop_type}' requested in config but not found in data.")

    for hop_type, target_count in sample_config.items():
        available_samples = categorized_data.get(hop_type, [])
        current_count = len(available_samples)
        print(f"  {hop_type}: Found {current_count} samples, need {target_count}.")

        if current_count == 0:
            continue

        # Sort the list for this category by ID
        available_samples.sort(key=lambda x: x["id"])

        selected_samples_for_hop = []
        if current_count < target_count:
            print(f"  Warning: Not enough samples for {hop_type}. Taking all {current_count} sorted samples.")
            selected_samples_for_hop = available_samples
        else:
            # Select target_count indices spread uniformly across the available samples
            print(f"  Selecting {target_count} samples uniformly from {current_count}...")
            # Calculate indices using integer interpretation of evenly spaced points
            indices_to_take = [int(i * current_count / target_count) for i in range(target_count)]
            # Ensure uniqueness in case of rounding issues with small numbers (though unlikely here)
            indices_to_take = sorted(list(set(indices_to_take)))
            # Adjust if rounding resulted in fewer than target_count unique indices
            while len(indices_to_take) < target_count:
                # This is a fallback, shouldn't happen if current_count >= target_count
                # Add indices from the end if needed, avoiding duplicates
                next_idx = indices_to_take[-1] + 1
                if next_idx < current_count and next_idx not in indices_to_take:
                    indices_to_take.append(next_idx)
                else:  # Should not be reachable if logic is sound
                    break

            # Select samples at the calculated indices
            selected_samples_for_hop = [
                available_samples[idx] for idx in indices_to_take[:target_count]
            ]  # Ensure we take exactly target_count

        final_sample_list.extend(selected_samples_for_hop)

    print(f"Selected {len(final_sample_list)} samples in total.")

    # Sort the final combined list by ID for consistent output order
    print("Sorting the final combined sample list by ID...")
    final_sample_list.sort(key=lambda x: x["id"])

    # Process and write the selected samples
    print(f"Processing and writing {len(final_sample_list)} selected samples to {output_path}...")
    count = 0
    try:
        with open(output_path, "w", encoding="utf-8") as outfile:
            for data in final_sample_list:
                try:
                    supporting_paragraphs = [
                        p["paragraph_text"] for p in data.get("paragraphs", []) if p.get("is_supporting", False)
                    ]

                    main_answer = data.get("answer", "")
                    aliases = data.get("answer_aliases", [])

                    all_answers = [main_answer] + (aliases if isinstance(aliases, list) else [])
                    valid_answers = [str(ans).strip() for ans in all_answers if ans and str(ans).strip()]
                    unique_valid_answers = list(set(valid_answers))

                    combined_answer_str = " OR ".join(unique_valid_answers)

                    output_data = {
                        "id": data.get("id"),
                        "question": data.get("question"),
                        "answer": combined_answer_str,
                        "supporting_paragraphs": supporting_paragraphs,
                    }
                    outfile.write(json.dumps(output_data) + "\n")
                    count += 1
                except KeyError as e:
                    print(f"Skipping sample due to missing key {e}: {data.get('id')}")
        print(f"Successfully processed and wrote {count} samples.")
    except Exception as e:
        print(f"An unexpected error occurred during writing: {e}")


if __name__ == "__main__":
    # Define file paths
    RAW_DIR = Path("data/raw")
    PROCESSED_DIR = Path("data/processed")

    # Define detailed sampling configuration
    SAMPLING_CONFIG = {
        "2hop": 400,
        "3hop1": 150,
        "3hop2": 150,
        "4hop1": 100,
        "4hop2": 100,
        "4hop3": 100,
    }  # Total = 1000

    transform_musique_data(
        str(RAW_DIR / "musique_ans_v1.0_train.jsonl"), str(PROCESSED_DIR / "questions.jsonl"), SAMPLING_CONFIG
    )

    print(
        "\nMusique JSONL transformation and detailed deterministic sampling (uniform selection from sorted) complete."
    )
    # Note: Dev/Test files are not processed by default with this sampling logic.
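A short sanity-check sketch (not part of the original script) that counts the generated questions.jsonl by detailed hop prefix; the totals should mirror SAMPLING_CONFIG if every category had enough samples:

import json
import re
from collections import Counter

counts = Counter()
with open("data/processed/questions.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        q_id = json.loads(line)["id"]
        m = re.match(r"^(2hop|3hop[12]|4hop[123])__", q_id)
        counts[m.group(1) if m else "other"] += 1
print(dict(counts))  # e.g. expected {'2hop': 400, '3hop1': 150, ...}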
@ -1 +0,0 @@
Subproject commit 7e60ab26825a452f8ee8eb19799d1cb6c1746326