parent
bd02305efb
commit
14ef79a4f5
@ -0,0 +1,659 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from flashrag.config import Config
|
||||
from flashrag.generator.generator import BaseGenerator
|
||||
from flashrag.pipeline import BasicPipeline
|
||||
from flashrag.retriever.retriever import BaseTextRetriever
|
||||
from flashrag.utils import get_dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from config import logger
|
||||
from src.agent import Agent, AgenticOutputs
|
||||
from src.prompts import build_user_prompt, get_system_prompt
|
||||
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
|
||||
|
||||
|
||||
def retry(max_retries=10, sleep=1):
|
||||
"""Decorator to retry a function with exponential backoff."""
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
func_name = func.__name__
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.warning(f"Attempt {attempt + 1} of {func_name} failed: {e}")
|
||||
if attempt == max_retries - 1:
|
||||
logger.error(f"Function {func_name} failed after {max_retries} retries.", exc_info=True)
|
||||
raise e
|
||||
backoff_time = sleep * (2**attempt)
|
||||
logger.info(f"Retrying {func_name} in {backoff_time:.2f} seconds...")
|
||||
time.sleep(backoff_time)
|
||||
logger.error(f"Function {func_name} retry logic finished unexpectedly.")
|
||||
return None
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class RemoteRetriever(BaseTextRetriever):
|
||||
"""A wrapper for remote retriever service with retry logic and logging."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""Initializes the RemoteRetriever."""
|
||||
super().__init__(config)
|
||||
self.remote_url = f"http://{getattr(config, 'remote_retriever_url', 'localhost:8001')}"
|
||||
self.topk = getattr(config, "retriever_topk", 5)
|
||||
logger.info(f"🔗 Remote retriever URL: {self.remote_url}")
|
||||
|
||||
@retry(max_retries=3, sleep=2)
|
||||
def _search(self, query: str, num: int | None = None, return_score: bool = False) -> list[dict]:
|
||||
"""Search for documents using the remote retriever service."""
|
||||
num = num if num is not None else self.topk
|
||||
url = f"{self.remote_url}/search"
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"query": query, "top_n": num, "return_score": return_score},
|
||||
timeout=30,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
results = response.json()
|
||||
return results
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Search request timed out after 30 seconds for query: {query[:50]}...")
|
||||
raise
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(f"Could not connect to search service at {url}")
|
||||
raise
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Search request failed: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected search error: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@retry(max_retries=3, sleep=2)
|
||||
def _batch_search(
|
||||
self, queries: list[str], num: int | None = None, return_score: bool = False
|
||||
) -> list[list[dict]]:
|
||||
"""Batch search for documents using the remote retriever service."""
|
||||
num = num if num is not None else self.topk
|
||||
url = f"{self.remote_url}/batch_search"
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
json={"query": queries, "top_n": num, "return_score": return_score},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
return results
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error(f"Batch search request timed out after 60 seconds for {len(queries)} queries.")
|
||||
raise
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(f"Could not connect to batch search service at {url}")
|
||||
raise
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Batch search request failed: {e}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected batch search error: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
class ReSearchPipeline(BasicPipeline):
|
||||
"""Pipeline for ReSearch method using Agent for generation and tool use."""
|
||||
|
||||
def __init__(
|
||||
self, config: Config, retriever: BaseTextRetriever | None = None, generator: BaseGenerator | None = None
|
||||
):
|
||||
"""Initializes the ReSearchPipeline."""
|
||||
super().__init__(config)
|
||||
logger.info("🔧 Initializing ReSearchPipeline...")
|
||||
|
||||
self.retriever = retriever or RemoteRetriever(config)
|
||||
|
||||
self.generator = generator or SGLRemoteGenerator(config)
|
||||
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.generator_model_path, trust_remote_code=True)
|
||||
if not self.tokenizer.pad_token:
|
||||
logger.warning("Tokenizer does not have a pad token; setting to eos_token.")
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
self.tokenizer.padding_side = "left"
|
||||
logger.info("✅ Tokenizer initialized.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize tokenizer: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
tokenizer_name = self.tokenizer.name_or_path.lower()
|
||||
|
||||
if "deepseek-ai/deepseek-r1-distill" in tokenizer_name:
|
||||
adapter = R1DistilTokenizerAdapter()
|
||||
elif "llama" in tokenizer_name:
|
||||
adapter = LlamaTokenizerAdapter()
|
||||
else:
|
||||
logger.warning(f"Unknown tokenizer type '{tokenizer_name}', defaulting to R1DistilTokenizerAdapter.")
|
||||
adapter = R1DistilTokenizerAdapter()
|
||||
logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")
|
||||
|
||||
def retriever_search(query: str, return_type=str, results: int = 2):
|
||||
try:
|
||||
search_results = self.retriever._search(query, num=results)
|
||||
return self.format_search_results(search_results)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during agent's retriever search for query '{query[:50]}...': {e}", exc_info=True)
|
||||
return "<information>Search failed due to an internal error."
|
||||
|
||||
self.agent = Agent(adapter, search_fn=retriever_search)
|
||||
logger.info("✅ Agent initialized.")
|
||||
logger.info("✅ ReSearchPipeline initialized successfully.")
|
||||
|
||||
def format_search_results(self, search_results: list[dict]) -> str:
|
||||
"""Formats search results into a string for the agent prompt."""
|
||||
if not search_results:
|
||||
return "<information>No results found.</information>"
|
||||
max_content_len = 500
|
||||
formatted = "\n-------\n".join(
|
||||
[
|
||||
f"Result {i + 1}: {r.get('contents', 'N/A')[:max_content_len]}{'...' if len(r.get('contents', '')) > max_content_len else ''}"
|
||||
for i, r in enumerate(search_results)
|
||||
]
|
||||
)
|
||||
formatted_str = f"<information>{formatted}</information>"
|
||||
|
||||
return formatted_str
|
||||
|
||||
def extract_search_query(self, text: str) -> str | None:
|
||||
"""Extract search query from text between <search> tags."""
|
||||
pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
|
||||
matches = pattern.findall(text)
|
||||
if matches:
|
||||
query = matches[-1].strip()
|
||||
return query
|
||||
return None
|
||||
|
||||
def extract_answer(self, text: str) -> str | None:
|
||||
"""Extract answer from text between <answer> tags."""
|
||||
pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
|
||||
matches = pattern.findall(text)
|
||||
if matches:
|
||||
answer = matches[-1].strip()
|
||||
|
||||
return answer
|
||||
|
||||
return None
|
||||
|
||||
def run(self, dataset, do_eval: bool = True, pred_process_fun=None):
|
||||
"""Runs the ReSearch pipeline on the dataset using the Agent."""
|
||||
logger.info(f"🏃 Starting ReSearch pipeline run with {len(dataset)} items...")
|
||||
|
||||
try:
|
||||
questions = [item.question if hasattr(item, "question") else item["question"] for item in dataset]
|
||||
|
||||
except (KeyError, AttributeError, TypeError) as e:
|
||||
logger.error(f"Failed to extract questions from dataset items. Error: {e}", exc_info=True)
|
||||
logger.error("Ensure dataset items have a 'question' key or attribute.")
|
||||
return dataset
|
||||
|
||||
agent_max_generations = getattr(self.config, "agent_max_generations", 20)
|
||||
generator_max_output_len = getattr(self.config, "generator_max_output_len", 1024)
|
||||
|
||||
try:
|
||||
logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
|
||||
agent_outputs: AgenticOutputs = self.agent.run_agent(
|
||||
generate_fn=self.generator.generate,
|
||||
tokenizer=self.tokenizer,
|
||||
questions=questions,
|
||||
max_generations=agent_max_generations,
|
||||
max_new_tokens=generator_max_output_len,
|
||||
)
|
||||
final_responses = agent_outputs.final_response_str
|
||||
logger.info(f"✅ Agent inference completed. Received {len(final_responses)} final responses.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent run failed during inference: {e}", exc_info=True)
|
||||
logger.warning("Agent run failed, attempting evaluation with potentially incomplete results.")
|
||||
for item in dataset:
|
||||
if hasattr(item, "update_output"):
|
||||
item.update_output("pred", "AGENT_ERROR")
|
||||
elif isinstance(item, dict):
|
||||
item["pred"] = "AGENT_ERROR"
|
||||
|
||||
logger.info("📝 Extracting answers and updating dataset items...")
|
||||
num_updated = 0
|
||||
num_missing_answers = 0
|
||||
if len(final_responses) == len(dataset):
|
||||
for i, item in enumerate(dataset):
|
||||
response = final_responses[i]
|
||||
answer = self.extract_answer(response)
|
||||
pred_to_save = answer if answer is not None else ""
|
||||
|
||||
if answer is None:
|
||||
num_missing_answers += 1
|
||||
|
||||
if hasattr(item, "update_output"):
|
||||
item.update_output("pred", pred_to_save)
|
||||
item.update_output("final_response", response)
|
||||
num_updated += 1
|
||||
elif isinstance(item, dict):
|
||||
item["pred"] = pred_to_save
|
||||
item["final_response"] = response
|
||||
num_updated += 1
|
||||
else:
|
||||
logger.warning(f"Item {i} has unknown type {type(item)}, cannot update with prediction.")
|
||||
|
||||
logger.info(f"Updated {num_updated}/{len(dataset)} dataset items with predictions.")
|
||||
if num_missing_answers > 0:
|
||||
logger.warning(f"{num_missing_answers} items had no <answer> tag.")
|
||||
else:
|
||||
logger.error(
|
||||
f"Mismatch between dataset size ({len(dataset)}) and number of agent responses ({len(final_responses)}). Cannot reliably update dataset."
|
||||
)
|
||||
for item in dataset:
|
||||
if hasattr(item, "update_output"):
|
||||
item.update_output("pred", "RESPONSE_COUNT_MISMATCH")
|
||||
elif isinstance(item, dict):
|
||||
item["pred"] = "RESPONSE_COUNT_MISMATCH"
|
||||
|
||||
if do_eval:
|
||||
logger.info("📊 Evaluating results using BasicPipeline.evaluate...")
|
||||
try:
|
||||
dataset = self.evaluate(dataset, do_eval=True, pred_process_fun=pred_process_fun)
|
||||
logger.info("✅ Evaluation completed via base class method.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during BasicPipeline.evaluate: {e}", exc_info=True)
|
||||
logger.warning("Evaluation may be incomplete.")
|
||||
else:
|
||||
logger.info("Skipping evaluation step as do_eval=False.")
|
||||
|
||||
logger.info("✅ ReSearch pipeline run finished.")
|
||||
return dataset
|
||||
|
||||
|
||||
class SGLRemoteGenerator(BaseGenerator):
|
||||
"""Class for decoder-only generator, based on SGLang remote service."""
|
||||
|
||||
def __init__(self, config: Config):
|
||||
"""Initializes the SGLRemoteGenerator."""
|
||||
super().__init__(config)
|
||||
logger.info("🔧 Initializing SGLRemoteGenerator...")
|
||||
sgl_url = getattr(config, "sgl_remote_url", "localhost:8002")
|
||||
self.sgl_remote_url = f"http://{sgl_url}/generate"
|
||||
self.health_check_url = f"http://{sgl_url}/health"
|
||||
logger.info(f"🔗 Remote Generator URL: {self.sgl_remote_url}")
|
||||
self.model_path = getattr(config, "generator_model_path", None)
|
||||
if not self.model_path:
|
||||
logger.error("generator_model_path not found in config!")
|
||||
raise ValueError("generator_model_path is required for SGLRemoteGenerator")
|
||||
|
||||
try:
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
|
||||
logger.info("✅ Tokenizer loaded for generator.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load tokenizer for generator from {self.model_path}: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
self.generation_params = getattr(config, "generation_params", {})
|
||||
self.config = config
|
||||
|
||||
self._check_health()
|
||||
|
||||
def _check_health(self):
|
||||
"""Checks the health of the remote generator service."""
|
||||
try:
|
||||
test_response = requests.get(self.health_check_url, timeout=5)
|
||||
test_response.raise_for_status()
|
||||
logger.info("✅ Remote generator service is available")
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Could not connect or verify remote generator service at {self.health_check_url}: {str(e)}")
|
||||
logger.warning("Please ensure the SGLang service is running and accessible.")
|
||||
|
||||
@retry(max_retries=5, sleep=2)
|
||||
def generate(
|
||||
self,
|
||||
input_list: list[str] | str,
|
||||
return_raw_output: bool = False,
|
||||
return_scores: bool = False,
|
||||
**params,
|
||||
) -> list[str] | tuple[list[str], list[list[float]]] | list[dict]:
|
||||
"""Generates text using the remote SGLang service."""
|
||||
if isinstance(input_list, str):
|
||||
input_list = [input_list]
|
||||
if not isinstance(input_list, list) or not all(isinstance(item, str) for item in input_list):
|
||||
raise ValueError("Input must be a string or a list of strings.")
|
||||
|
||||
batch_size = len(input_list)
|
||||
data_to_remote = {"text": input_list}
|
||||
|
||||
effective_params = deepcopy(self.generation_params)
|
||||
effective_params.update(params)
|
||||
|
||||
curr_sampling_params = {}
|
||||
if effective_params.get("do_sample", True) is False:
|
||||
curr_sampling_params["temperature"] = 0.0
|
||||
else:
|
||||
curr_sampling_params["temperature"] = effective_params.get(
|
||||
"temperature", getattr(self.config, "temperature", 0.7)
|
||||
)
|
||||
|
||||
default_max_tokens = getattr(self.config, "generator_max_output_len", 1024)
|
||||
curr_sampling_params["max_new_tokens"] = effective_params.get("max_new_tokens", default_max_tokens)
|
||||
|
||||
stop_sequences = effective_params.get("stop", [])
|
||||
if isinstance(stop_sequences, str):
|
||||
stop_sequences = [stop_sequences]
|
||||
if stop_sequences:
|
||||
curr_sampling_params["stop"] = stop_sequences
|
||||
|
||||
keys_to_remove = ["do_sample", "temperature", "max_new_tokens", "stop"]
|
||||
for key in keys_to_remove:
|
||||
effective_params.pop(key, None)
|
||||
|
||||
if "top_p" in effective_params:
|
||||
curr_sampling_params["top_p"] = effective_params["top_p"]
|
||||
if "top_k" in effective_params:
|
||||
curr_sampling_params["top_k"] = effective_params["top_k"]
|
||||
|
||||
data_to_remote["sampling_params"] = curr_sampling_params
|
||||
|
||||
if return_scores:
|
||||
data_to_remote["return_logprob"] = True
|
||||
data_to_remote["top_logprobs_num"] = getattr(self.config, "top_logprobs_num", 2)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.sgl_remote_url, json=data_to_remote, timeout=120, headers={"Content-Type": "application/json"}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
response_list = response.json()
|
||||
|
||||
if return_raw_output:
|
||||
return response_list
|
||||
|
||||
generated_text = []
|
||||
for item in response_list:
|
||||
text = item.get("text", "")
|
||||
finish_reason = item.get("meta_info", {}).get("finish_reason", {})
|
||||
matched_stop = finish_reason.get("matched")
|
||||
if matched_stop and curr_sampling_params.get("stop") and matched_stop in curr_sampling_params["stop"]:
|
||||
text += matched_stop
|
||||
generated_text.append(text)
|
||||
|
||||
if return_scores:
|
||||
scores = []
|
||||
for resp_item in response_list:
|
||||
logprobs_list = resp_item.get("meta_info", {}).get("output_token_logprobs", [])
|
||||
token_scores = [
|
||||
np.exp(logprob[0]) if (logprob and len(logprob) > 0) else 0.0 for logprob in logprobs_list
|
||||
]
|
||||
scores.append(token_scores)
|
||||
return generated_text, scores
|
||||
else:
|
||||
return generated_text
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
logger.error("Generation request timed out after 120 seconds.")
|
||||
raise
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.error(f"Could not connect to remote generator service at {self.sgl_remote_url}.")
|
||||
raise
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Network error during generation: {str(e)}", exc_info=True)
|
||||
raise
|
||||
except json.JSONDecodeError as e:
|
||||
response_text = "Unknown (error occurred before response object assignment)"
|
||||
if "response" in locals() and hasattr(response, "text"):
|
||||
response_text = response.text[:500]
|
||||
logger.error(
|
||||
f"Failed to decode JSON response from {self.sgl_remote_url}. Response text: {response_text}...",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during generation: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def load_dataset_items(config: Config, split: str) -> list[dict | object]:
|
||||
"""Loads dataset items using flashrag's get_dataset."""
|
||||
logger.info(f"📚 Loading dataset: {config.dataset_name}, Split: {split}")
|
||||
try:
|
||||
all_splits = get_dataset(config)
|
||||
if split not in all_splits:
|
||||
logger.error(
|
||||
f"Split '{split}' not found in dataset '{config.dataset_name}'. Available splits: {list(all_splits.keys())}"
|
||||
)
|
||||
return []
|
||||
dataset_items = all_splits[split]
|
||||
logger.info(f"Successfully loaded {len(dataset_items)} items for split '{split}'.")
|
||||
|
||||
return dataset_items
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
f"Dataset files not found for '{config.dataset_name}' in '{config.data_dir}'. Check config and paths."
|
||||
)
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading dataset using get_dataset: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
def save_results(args: argparse.Namespace, config: Config, result_dataset, run_duration: float):
|
||||
"""Saves summary and debug information."""
|
||||
logger.info("💾 Saving results...")
|
||||
summary_file = os.path.join(args.save_dir, f"{args.save_note}_summary.txt")
|
||||
debug_file = os.path.join(args.save_dir, f"{args.save_note}_debug.json")
|
||||
|
||||
num_items = len(result_dataset)
|
||||
|
||||
logger.info(f"Saving summary results to {summary_file}...")
|
||||
try:
|
||||
with open(summary_file, "w", encoding="utf-8") as f:
|
||||
f.write("EVALUATION SUMMARY\n")
|
||||
f.write("=================\n\n")
|
||||
f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
|
||||
f.write(f"Run Duration: {run_duration:.2f} seconds\n")
|
||||
f.write(f"Dataset: {config.dataset_name} ({args.split} split)\n")
|
||||
f.write(f"Model: {config.generator_model_path}\n")
|
||||
f.write(f"Retriever: {config.remote_retriever_url}\n")
|
||||
f.write(f"Agent Max Generations: {getattr(config, 'agent_max_generations', 'N/A')}\n")
|
||||
f.write(f"Generator Max Output Len: {getattr(config, 'generator_max_output_len', 'N/A')}\n\n")
|
||||
f.write(f"Total items processed: {num_items}\n")
|
||||
f.write("\nNote: Verification was skipped in this run.\n")
|
||||
f.write("Note: Overall metrics (like EM, F1) are usually printed to console by evaluate method.\n")
|
||||
|
||||
logger.info(f"✅ Summary saved to {summary_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving summary file '{summary_file}': {e}", exc_info=True)
|
||||
|
||||
logger.info(f"Saving debug information (predictions & responses) to {debug_file}...")
|
||||
try:
|
||||
debug_data = []
|
||||
for i, item in enumerate(result_dataset):
|
||||
item_data: dict[str, object] = {}
|
||||
|
||||
def get_item_value(data_item, key_or_attr: str) -> str | int | float | list | bool | None:
|
||||
if isinstance(data_item, dict):
|
||||
return data_item.get(key_or_attr)
|
||||
elif hasattr(data_item, key_or_attr):
|
||||
return getattr(data_item, key_or_attr)
|
||||
return None
|
||||
|
||||
item_data["item_index"] = i
|
||||
item_data["question"] = get_item_value(item, "question")
|
||||
item_data["prediction"] = get_item_value(item, "pred")
|
||||
item_data["final_response"] = get_item_value(item, "final_response")
|
||||
|
||||
gt_answer_val = None
|
||||
try:
|
||||
gt_answer_val = get_item_value(item, "answer")
|
||||
if gt_answer_val is None:
|
||||
answers_list = get_item_value(item, "answers")
|
||||
if isinstance(answers_list, list) and answers_list:
|
||||
raw_ans = answers_list[0]
|
||||
if isinstance(raw_ans, (str, int, float, bool)):
|
||||
gt_answer_val = raw_ans
|
||||
else:
|
||||
gt_answer_val = str(raw_ans)
|
||||
elif not isinstance(gt_answer_val, (str, int, float, bool)):
|
||||
gt_answer_val = str(gt_answer_val)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not safely get ground truth for item {i}: {e}")
|
||||
gt_answer_val = "ERROR_GETTING_ANSWER"
|
||||
item_data["ground_truth"] = gt_answer_val
|
||||
|
||||
eval_score_val = None
|
||||
try:
|
||||
eval_score_val = get_item_value(item, "score")
|
||||
if not isinstance(eval_score_val, (str, int, float, bool, type(None))):
|
||||
eval_score_val = str(eval_score_val)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not safely get score for item {i}: {e}")
|
||||
eval_score_val = "ERROR_GETTING_SCORE"
|
||||
item_data["eval_score"] = eval_score_val
|
||||
|
||||
debug_data.append(item_data)
|
||||
|
||||
with open(debug_file, "w", encoding="utf-8") as f:
|
||||
json.dump(debug_data, f, indent=2, ensure_ascii=False)
|
||||
logger.info(f"✅ Debug information saved to {debug_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving debug file '{debug_file}': {e}", exc_info=True)
|
||||
|
||||
|
||||
def research(args: argparse.Namespace, config: Config):
|
||||
"""Main function to run the research evaluation pipeline."""
|
||||
logger.info("🚀 Starting research pipeline execution...")
|
||||
start_time = time.time()
|
||||
|
||||
test_data = load_dataset_items(config, args.split)
|
||||
if not test_data:
|
||||
logger.error("Failed to load test data. Exiting.")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("🏗️ Building ReSearchPipeline...")
|
||||
pipeline = ReSearchPipeline(config)
|
||||
logger.info("✅ Pipeline built successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
|
||||
return
|
||||
|
||||
agent_max_generations = getattr(config, "agent_max_generations", 20)
|
||||
generator_max_output_len = getattr(config, "generator_max_output_len", 1024)
|
||||
|
||||
try:
|
||||
logger.info("🏃 Starting pipeline run...")
|
||||
result_dataset = pipeline.run(test_data, do_eval=True)
|
||||
logger.info("✅ Pipeline run completed.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during pipeline run: {e}", exc_info=True)
|
||||
result_dataset = test_data
|
||||
logger.warning("Pipeline run failed, attempting to save inputs/partial results.")
|
||||
|
||||
run_duration = time.time() - start_time
|
||||
logger.info(f"Total run duration: {run_duration:.2f} seconds.")
|
||||
save_results(args, config, result_dataset, run_duration)
|
||||
|
||||
logger.info("🏁 Research pipeline execution finished.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Running ReSearch Evaluation Pipeline")
|
||||
parser.add_argument(
|
||||
"--config_path", type=str, default="./eval_config.yaml", help="Path to the main FlashRAG config file."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default="bamboogle",
|
||||
help="Name of the dataset (must match config or data_dir structure).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split", type=str, default="test", help="Dataset split to evaluate (e.g., test, validation)."
|
||||
)
|
||||
parser.add_argument("--save_dir", type=str, default="./output_logs", help="Directory to save logs and results.")
|
||||
parser.add_argument("--save_note", type=str, default="research_run", help="A note to prepend to saved filenames.")
|
||||
|
||||
parser.add_argument("--data_dir", type=str, help="Override data directory specified in config.")
|
||||
parser.add_argument(
|
||||
"--sgl_remote_url", type=str, help="Override SGLang remote generator URL (e.g., localhost:8002)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remote_retriever_url", type=str, help="Override remote retriever URL (e.g., localhost:8001)."
|
||||
)
|
||||
parser.add_argument("--generator_model_path", type=str, help="Override generator model path specified in config.")
|
||||
parser.add_argument("--retriever_topk", type=int, help="Override retriever top K.")
|
||||
parser.add_argument("--generator_max_output_len", type=int, help="Override generator max output length.")
|
||||
parser.add_argument("--agent_max_generations", type=int, help="Override agent max interaction turns.")
|
||||
|
||||
args = parser.parse_args()
|
||||
logger.info(f"Starting evaluation script with arguments: {args}")
|
||||
|
||||
try:
|
||||
os.makedirs(args.save_dir, exist_ok=True)
|
||||
logger.info(f"💾 Logs and results will be saved to: {args.save_dir}")
|
||||
except OSError as e:
|
||||
logger.error(f"Could not create save directory '{args.save_dir}': {e}", exc_info=True)
|
||||
exit(1)
|
||||
|
||||
config_overrides = {
|
||||
k: v
|
||||
for k, v in vars(args).items()
|
||||
if v is not None
|
||||
and k
|
||||
not in [
|
||||
"config_path",
|
||||
"dataset_name",
|
||||
"split",
|
||||
"save_dir",
|
||||
"save_note",
|
||||
]
|
||||
}
|
||||
|
||||
logger.info(f"🔧 Loading configuration from: {args.config_path}")
|
||||
try:
|
||||
config = Config(args.config_path, config_dict=config_overrides)
|
||||
config.dataset_name = args.dataset_name
|
||||
if args.data_dir:
|
||||
config.data_dir = args.data_dir
|
||||
|
||||
logger.info(f"Effective data_dir: {getattr(config, 'data_dir', 'N/A')}")
|
||||
logger.info(f"Effective generator_model_path: {getattr(config, 'generator_model_path', 'N/A')}")
|
||||
logger.info(f"Effective sgl_remote_url: {getattr(config, 'sgl_remote_url', 'N/A')}")
|
||||
logger.info(f"Effective remote_retriever_url: {getattr(config, 'remote_retriever_url', 'N/A')}")
|
||||
|
||||
logger.info("✅ Config loaded and potentially overridden by CLI arguments.")
|
||||
|
||||
config["dataset_path"] = os.path.join(config.data_dir, config.dataset_name)
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Config file not found at '{args.config_path}'. Please check the path.")
|
||||
exit(1)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading or processing configuration: {e}", exc_info=True)
|
||||
exit(1)
|
||||
|
||||
research(args, config)
|
@ -0,0 +1,68 @@
|
||||
import argparse
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from config import DATA_DIR
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command line arguments.
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Download FlashRAG datasets from HuggingFace Hub")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="RUC-NLPIR/FlashRAG_datasets",
|
||||
help="HuggingFace repository IDs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
default=DATA_DIR / "flashrag_datasets",
|
||||
help="Local directory to save model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to download model."""
|
||||
args = parse_args()
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Configuration
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
|
||||
ALLOW_PATTERNS = [
|
||||
"*retrieval-corpus*",
|
||||
"*bamboogle*",
|
||||
"*nq*",
|
||||
]
|
||||
|
||||
# Download the model
|
||||
snapshot_download(
|
||||
token=HF_TOKEN,
|
||||
repo_id=args.repo_id,
|
||||
local_dir=args.local_dir,
|
||||
repo_type="dataset",
|
||||
# ignore_patterns=IGNORE_PATTERNS,
|
||||
allow_patterns=ALLOW_PATTERNS,
|
||||
)
|
||||
|
||||
# unzip data/flashrag_datasets/retrieval-corpus/wiki18_100w.zip
|
||||
print("Unzipping wiki18_100w.zip. Might take a while...")
|
||||
zip_file_path = os.path.join(args.local_dir, "retrieval-corpus", "wiki18_100w.zip")
|
||||
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
|
||||
zip_ref.extractall(args.local_dir)
|
||||
|
||||
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,43 @@
|
||||
"""Download and extract FlashRAG index."""
|
||||
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from config import DATA_DIR
|
||||
|
||||
# Constants
|
||||
URL = "https://www.modelscope.cn/datasets/hhjinjiajie/FlashRAG_Dataset/resolve/master/retrieval_corpus/wiki18_100w_e5_index.zip"
|
||||
ZIP_NAME = "wiki18_100w_e5_index.zip"
|
||||
zip_path = DATA_DIR / ZIP_NAME
|
||||
|
||||
# Download with progress bar
|
||||
print("📥 Downloading index...")
|
||||
response = requests.get(URL, stream=True)
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
|
||||
with (
|
||||
open(zip_path, "wb") as f,
|
||||
tqdm(
|
||||
desc=ZIP_NAME,
|
||||
total=total_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as bar,
|
||||
):
|
||||
for data in response.iter_content(chunk_size=1024):
|
||||
size = f.write(data)
|
||||
bar.update(size)
|
||||
|
||||
# Extract
|
||||
print("📦 Extracting index...")
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
zip_ref.extractall(DATA_DIR)
|
||||
|
||||
# Clean up zip
|
||||
os.remove(zip_path)
|
||||
print("✅ Download and extraction completed successfully!")
|
||||
print(f"Index file is at: {DATA_DIR}/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index")
|
@ -0,0 +1,54 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from config import GENERATOR_MODEL_DIR, GENERATOR_MODEL_REPO_ID
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command line arguments.
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=GENERATOR_MODEL_REPO_ID,
|
||||
help="HuggingFace repository ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
default=GENERATOR_MODEL_DIR,
|
||||
help="Local directory to save model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to download model."""
|
||||
args = parse_args()
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Configuration
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
|
||||
print("Downloading model to", args.local_dir)
|
||||
|
||||
# Download the model
|
||||
snapshot_download(
|
||||
token=HF_TOKEN,
|
||||
repo_id=args.repo_id,
|
||||
local_dir=args.local_dir,
|
||||
repo_type="model",
|
||||
)
|
||||
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,54 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from config import RETRIEVER_MODEL_DIR, RETRIEVER_MODEL_REPO_ID
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command line arguments.
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Download model from HuggingFace Hub")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=RETRIEVER_MODEL_REPO_ID,
|
||||
help="HuggingFace repository ID",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local-dir",
|
||||
type=str,
|
||||
default=RETRIEVER_MODEL_DIR,
|
||||
help="Local directory to save model",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function to download model."""
|
||||
args = parse_args()
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Configuration
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
|
||||
print("Downloading model to", args.local_dir)
|
||||
|
||||
# Download the model
|
||||
snapshot_download(
|
||||
token=HF_TOKEN,
|
||||
repo_id=args.repo_id,
|
||||
local_dir=args.local_dir,
|
||||
repo_type="model",
|
||||
)
|
||||
print(f"✅ Done: {args.repo_id} -> {args.local_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -0,0 +1,9 @@
|
||||
# ------------------------------------------------Environment Settings------------------------------------------------#
|
||||
gpu_id: "0"
|
||||
|
||||
# -------------------------------------------------Retrieval Settings------------------------------------------------#
|
||||
# If set the name, the model path will be find in global paths
|
||||
retrieval_method: "e5" # name or path of the retrieval model.
|
||||
index_path: "/mnt/nas/thinhlpg/data/data00/jiajie_jin/flashrag_indexes/wiki_dpr_100w/e5_flat_inner.index" # path to the indexed file
|
||||
faiss_gpu: False # whether use gpu to hold index
|
||||
corpus_path: "/mnt/nas/thinhlpg/code/DeepSearch/data/flashrag_datasets/wiki18_100w.jsonl" # path to corpus in '.jsonl' format that store the documents
|
@ -0,0 +1,125 @@
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to sys.path to allow importing config
|
||||
# Assuming the script is at DeepSearch/scripts/serving/serve_generator.py
|
||||
# The project root (DeepSearch) is parents[2]
|
||||
PROJ_ROOT = Path(__file__).resolve().parents[2]
|
||||
if str(PROJ_ROOT) not in sys.path:
|
||||
sys.path.append(str(PROJ_ROOT))
|
||||
|
||||
# Import after adjusting sys.path
|
||||
try:
|
||||
from config import (
|
||||
GENERATOR_MODEL_REPO_ID,
|
||||
GENERATOR_SERVER_PORT,
|
||||
MODEL_CONFIG,
|
||||
logger,
|
||||
)
|
||||
except ImportError as e:
|
||||
# Use print here as logger might not be available if import failed
|
||||
print(
|
||||
f"Error importing config: {e}. Make sure config.py is in the project root ({PROJ_ROOT}) and added to sys.path."
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def launch_sglang_server(
|
||||
model_id: str,
|
||||
port: int,
|
||||
context_length: int,
|
||||
host: str = "0.0.0.0",
|
||||
dtype: str = "bfloat16",
|
||||
) -> None:
|
||||
"""Launches the SGLang server using specified configurations.
|
||||
|
||||
Args:
|
||||
model_id: The Hugging Face repository ID of the model.
|
||||
port: The port number for the server.
|
||||
context_length: The maximum context length for the model.
|
||||
host: The host address for the server.
|
||||
dtype: The data type for the model (e.g., 'bfloat16', 'float16').
|
||||
"""
|
||||
command = [
|
||||
sys.executable, # Use the current Python interpreter
|
||||
"-m",
|
||||
"sglang.launch_server",
|
||||
"--model-path",
|
||||
model_id,
|
||||
"--context-length",
|
||||
str(context_length),
|
||||
"--enable-metrics",
|
||||
"--dtype",
|
||||
dtype,
|
||||
"--host",
|
||||
host,
|
||||
"--port",
|
||||
str(port),
|
||||
"--trust-remote-code",
|
||||
# Recommended by SGLang for stability sometimes
|
||||
"--disable-overlap",
|
||||
# Can sometimes cause issues
|
||||
"--disable-radix-cache",
|
||||
]
|
||||
|
||||
# Log the command clearly
|
||||
command_str = " ".join(command)
|
||||
logger.info(f"🚀 Launching SGLang server with command: {command_str}")
|
||||
|
||||
process = None # Initialize process to None
|
||||
try:
|
||||
# Use Popen to start the server process
|
||||
# It runs in the foreground relative to this script,
|
||||
# but allows us to catch KeyboardInterrupt cleanly.
|
||||
process = subprocess.Popen(command)
|
||||
# Wait for the process to complete (e.g., user interruption)
|
||||
process.wait()
|
||||
# Check return code after waiting
|
||||
if process.returncode != 0:
|
||||
logger.error(f"💥 SGLang server process exited with error code: {process.returncode}")
|
||||
sys.exit(process.returncode)
|
||||
else:
|
||||
logger.info("✅ SGLang server process finished gracefully.")
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"💥 Error: Python executable or sglang module not found.")
|
||||
logger.error(f"Ensure '{sys.executable}' is correct and sglang is installed.")
|
||||
sys.exit(1)
|
||||
except KeyboardInterrupt:
|
||||
logger.info("🛑 SGLang server launch interrupted by user. Stopping server...")
|
||||
# Attempt to terminate the process gracefully
|
||||
if process and process.poll() is None: # Check if process exists and is running
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=5) # Wait a bit for termination
|
||||
logger.info("✅ Server terminated gracefully.")
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.warning("⚠️ Server did not terminate gracefully, forcing kill.")
|
||||
process.kill()
|
||||
sys.exit(0) # Exit cleanly after interrupt
|
||||
except Exception as e:
|
||||
# Catch any other unexpected exceptions during launch or waiting
|
||||
logger.error(f"🚨 An unexpected error occurred: {e}")
|
||||
# Ensure process is cleaned up if it exists
|
||||
if process and process.poll() is None:
|
||||
process.kill()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Get context length from config, default to 8192
|
||||
context_len = MODEL_CONFIG.get("max_seq_length", 8192)
|
||||
|
||||
logger.info("----------------------------------------------------")
|
||||
logger.info("✨ Starting SGLang Generator Server ✨")
|
||||
logger.info(f" Model ID: {GENERATOR_MODEL_REPO_ID}")
|
||||
logger.info(f" Port: {GENERATOR_SERVER_PORT}")
|
||||
logger.info(f" Context Length: {context_len}")
|
||||
logger.info("----------------------------------------------------")
|
||||
|
||||
launch_sglang_server(
|
||||
model_id=GENERATOR_MODEL_REPO_ID,
|
||||
port=GENERATOR_SERVER_PORT,
|
||||
context_length=context_len,
|
||||
)
|
@ -0,0 +1,113 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from flashrag.config import Config
|
||||
from flashrag.utils import get_retriever
|
||||
from pydantic import BaseModel
|
||||
|
||||
from config import RETRIEVER_SERVER_PORT
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
retriever_list = []
|
||||
available_retrievers = deque()
|
||||
retriever_semaphore = None
|
||||
|
||||
|
||||
def init_retriever(args):
|
||||
global retriever_semaphore
|
||||
config = Config(args.config)
|
||||
for i in range(args.num_retriever):
|
||||
print(f"Initializing retriever {i + 1}/{args.num_retriever}")
|
||||
retriever = get_retriever(config)
|
||||
retriever_list.append(retriever)
|
||||
available_retrievers.append(i)
|
||||
# create a semaphore to limit the number of retrievers that can be used concurrently
|
||||
retriever_semaphore = asyncio.Semaphore(args.num_retriever)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
return {"status": "healthy", "retrievers": {"total": len(retriever_list), "available": len(available_retrievers)}}
|
||||
|
||||
|
||||
class QueryRequest(BaseModel):
|
||||
query: str
|
||||
top_n: int = 10
|
||||
return_score: bool = False
|
||||
|
||||
|
||||
class BatchQueryRequest(BaseModel):
|
||||
query: List[str]
|
||||
top_n: int = 10
|
||||
return_score: bool = False
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
id: str
|
||||
contents: str
|
||||
|
||||
|
||||
@app.post("/search", response_model=Union[Tuple[List[Document], List[float]], List[Document]])
|
||||
async def search(request: QueryRequest):
|
||||
query = request.query
|
||||
top_n = request.top_n
|
||||
return_score = request.return_score
|
||||
|
||||
if not query or not query.strip():
|
||||
print(f"Query content cannot be empty: {query}")
|
||||
raise HTTPException(status_code=400, detail="Query content cannot be empty")
|
||||
|
||||
async with retriever_semaphore:
|
||||
retriever_idx = available_retrievers.popleft()
|
||||
try:
|
||||
if return_score:
|
||||
results, scores = retriever_list[retriever_idx].search(query, top_n, return_score)
|
||||
return [Document(id=result["id"], contents=result["contents"]) for result in results], scores
|
||||
else:
|
||||
results = retriever_list[retriever_idx].search(query, top_n, return_score)
|
||||
return [Document(id=result["id"], contents=result["contents"]) for result in results]
|
||||
finally:
|
||||
available_retrievers.append(retriever_idx)
|
||||
|
||||
|
||||
@app.post("/batch_search", response_model=Union[List[List[Document]], Tuple[List[List[Document]], List[List[float]]]])
|
||||
async def batch_search(request: BatchQueryRequest):
|
||||
query = request.query
|
||||
top_n = request.top_n
|
||||
return_score = request.return_score
|
||||
|
||||
async with retriever_semaphore:
|
||||
retriever_idx = available_retrievers.popleft()
|
||||
try:
|
||||
if return_score:
|
||||
results, scores = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
|
||||
return [
|
||||
[Document(id=result["id"], contents=result["contents"]) for result in results[i]]
|
||||
for i in range(len(results))
|
||||
], scores
|
||||
else:
|
||||
results = retriever_list[retriever_idx].batch_search(query, top_n, return_score)
|
||||
return [
|
||||
[Document(id=result["id"], contents=result["contents"]) for result in results[i]]
|
||||
for i in range(len(results))
|
||||
]
|
||||
finally:
|
||||
available_retrievers.append(retriever_idx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", type=str, default="./retriever_config.yaml")
|
||||
parser.add_argument("--num_retriever", type=int, default=1)
|
||||
parser.add_argument("--port", type=int, default=RETRIEVER_SERVER_PORT)
|
||||
args = parser.parse_args()
|
||||
|
||||
init_retriever(args)
|
||||
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=args.port)
|
Loading…
Reference in new issue