feat: update max generations and output length in evaluation scripts, add memory fraction to server launch

main
thinhlpg 4 weeks ago
parent 7ee65269fb
commit 9738b80353

2
.gitignore vendored

@@ -18,7 +18,7 @@ logs/
*.code-workspace *.code-workspace
data/ data/
.gradio/ .gradio/
output_* output*
# Byte-compiled / optimized / DLL files # Byte-compiled / optimized / DLL files
__pycache__/ __pycache__/

@@ -3,7 +3,6 @@ import json
import os import os
import re import re
import time import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from copy import deepcopy from copy import deepcopy
from datetime import datetime from datetime import datetime
from functools import wraps from functools import wraps
@@ -15,12 +14,10 @@ from flashrag.generator.generator import BaseGenerator
from flashrag.pipeline import BasicPipeline from flashrag.pipeline import BasicPipeline
from flashrag.retriever.retriever import BaseTextRetriever from flashrag.retriever.retriever import BaseTextRetriever
from flashrag.utils import get_dataset from flashrag.utils import get_dataset
from tqdm import tqdm
from transformers import AutoTokenizer from transformers import AutoTokenizer
from config import logger from config import logger
from src.agent import Agent, AgenticOutputs from src.agent import Agent, AgenticOutputs
from src.prompts import build_user_prompt, get_system_prompt
from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
@@ -157,7 +154,7 @@ class ReSearchPipeline(BasicPipeline):
adapter = R1DistilTokenizerAdapter() adapter = R1DistilTokenizerAdapter()
logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}") logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")
def retriever_search(query: str, return_type=str, results: int = 2): def retriever_search(query: str, return_type=str, results: int = 5):
try: try:
search_results = self.retriever._search(query, num=results) search_results = self.retriever._search(query, num=results)
return self.format_search_results(search_results) return self.format_search_results(search_results)
@@ -216,8 +213,8 @@ class ReSearchPipeline(BasicPipeline):
logger.error("Ensure dataset items have a 'question' key or attribute.") logger.error("Ensure dataset items have a 'question' key or attribute.")
return dataset return dataset
agent_max_generations = getattr(self.config, "agent_max_generations", 20) agent_max_generations = getattr(self.config, "agent_max_generations", 32)
generator_max_output_len = getattr(self.config, "generator_max_output_len", 1024) generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576)
try: try:
logger.info(f"🤖 Running agent inference for {len(questions)} questions...") logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
@@ -422,7 +419,7 @@ class SGLRemoteGenerator(BaseGenerator):
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
logger.error(f"Network error during generation: {str(e)}", exc_info=True) logger.error(f"Network error during generation: {str(e)}", exc_info=True)
raise raise
except json.JSONDecodeError as e: except json.JSONDecodeError:
response_text = "Unknown (error occurred before response object assignment)" response_text = "Unknown (error occurred before response object assignment)"
if "response" in locals() and hasattr(response, "text"): if "response" in locals() and hasattr(response, "text"):
response_text = response.text[:500] response_text = response.text[:500]
@@ -561,8 +558,8 @@ def research(args: argparse.Namespace, config: Config):
logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True) logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
return return
agent_max_generations = getattr(config, "agent_max_generations", 20) agent_max_generations = getattr(config, "agent_max_generations", 32)
generator_max_output_len = getattr(config, "generator_max_output_len", 1024) generator_max_output_len = getattr(config, "generator_max_output_len", 24576)
try: try:
logger.info("🏃 Starting pipeline run...") logger.info("🏃 Starting pipeline run...")

@@ -56,6 +56,8 @@ def launch_sglang_server(
host, host,
"--port", "--port",
str(port), str(port),
"--mem-fraction-static",
"0.5",
"--trust-remote-code", "--trust-remote-code",
# Recommended by SGLang for stability sometimes # Recommended by SGLang for stability sometimes
"--disable-overlap", "--disable-overlap",

Loading…
Cancel
Save