feat: update max generations and output length in evaluation scripts, add memory fraction to server launch

main
thinhlpg 4 weeks ago
parent 7ee65269fb
commit 9738b80353

2
.gitignore vendored

@@ -18,7 +18,7 @@ logs/
*.code-workspace
data/
.gradio/
output_*
output*
# Byte-compiled / optimized / DLL files
__pycache__/

@@ -3,7 +3,6 @@ import json
import os
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy
from datetime import datetime
from functools import wraps
@@ -15,12 +14,10 @@ from flashrag.generator.generator import BaseGenerator
from flashrag.pipeline import BasicPipeline
from flashrag.retriever.retriever import BaseTextRetriever
from flashrag.utils import get_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from config import logger
from src.agent import Agent, AgenticOutputs
from src.prompts import build_user_prompt, get_system_prompt
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
@@ -157,7 +154,7 @@ class ReSearchPipeline(BasicPipeline):
adapter = R1DistilTokenizerAdapter()
logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")
def retriever_search(query: str, return_type=str, results: int = 2):
def retriever_search(query: str, return_type=str, results: int = 5):
try:
search_results = self.retriever._search(query, num=results)
return self.format_search_results(search_results)
@@ -216,8 +213,8 @@ class ReSearchPipeline(BasicPipeline):
logger.error("Ensure dataset items have a 'question' key or attribute.")
return dataset
agent_max_generations = getattr(self.config, "agent_max_generations", 20)
generator_max_output_len = getattr(self.config, "generator_max_output_len", 1024)
agent_max_generations = getattr(self.config, "agent_max_generations", 32)
generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576)
try:
logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
@@ -422,7 +419,7 @@ class SGLRemoteGenerator(BaseGenerator):
except requests.exceptions.RequestException as e:
logger.error(f"Network error during generation: {str(e)}", exc_info=True)
raise
except json.JSONDecodeError as e:
except json.JSONDecodeError:
response_text = "Unknown (error occurred before response object assignment)"
if "response" in locals() and hasattr(response, "text"):
response_text = response.text[:500]
@@ -561,8 +558,8 @@ def research(args: argparse.Namespace, config: Config):
logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
return
agent_max_generations = getattr(config, "agent_max_generations", 20)
generator_max_output_len = getattr(config, "generator_max_output_len", 1024)
agent_max_generations = getattr(config, "agent_max_generations", 32)
generator_max_output_len = getattr(config, "generator_max_output_len", 24576)
try:
logger.info("🏃 Starting pipeline run...")

@@ -56,6 +56,8 @@ def launch_sglang_server(
host,
"--port",
str(port),
"--mem-fraction-static",
"0.5",
"--trust-remote-code",
# Recommended by SGLang for stability sometimes
"--disable-overlap",

Loading…
Cancel
Save