From 9738b80353c9be77e4b0ac2163274d29354dfd2a Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Tue, 15 Apr 2025 05:04:33 +0000
Subject: [PATCH] feat: update max generations and output length in evaluation
 scripts, add memory fraction to server launch

---
 .gitignore                         |  2 +-
 scripts/evaluation/run_eval.py     | 15 ++++++---------
 scripts/serving/serve_generator.py |  2 ++
 3 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1a2d02e..be40a92 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,7 +18,7 @@ logs/
 *.code-workspace
 data/
 .gradio/
-output_*
+output*
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
diff --git a/scripts/evaluation/run_eval.py b/scripts/evaluation/run_eval.py
index a0ccb16..7b6810d 100644
--- a/scripts/evaluation/run_eval.py
+++ b/scripts/evaluation/run_eval.py
@@ -3,7 +3,6 @@ import json
 import os
 import re
 import time
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from copy import deepcopy
 from datetime import datetime
 from functools import wraps
@@ -15,12 +14,10 @@ from flashrag.generator.generator import BaseGenerator
 from flashrag.pipeline import BasicPipeline
 from flashrag.retriever.retriever import BaseTextRetriever
 from flashrag.utils import get_dataset
-from tqdm import tqdm
 from transformers import AutoTokenizer
 
 from config import logger
 from src.agent import Agent, AgenticOutputs
-from src.prompts import build_user_prompt, get_system_prompt
 from src.tokenizer_adapter import LlamaTokenizerAdapter, R1DistilTokenizerAdapter
 
 
@@ -157,7 +154,7 @@ class ReSearchPipeline(BasicPipeline):
             adapter = R1DistilTokenizerAdapter()
         logger.info(f"🔩 Using Tokenizer Adapter: {type(adapter).__name__}")
 
-        def retriever_search(query: str, return_type=str, results: int = 2):
+        def retriever_search(query: str, return_type=str, results: int = 5):
             try:
                 search_results = self.retriever._search(query, num=results)
                 return self.format_search_results(search_results)
@@ -216,8 +213,8 @@ class ReSearchPipeline(BasicPipeline):
             logger.error("Ensure dataset items have a 'question' key or attribute.")
             return dataset
 
-        agent_max_generations = getattr(self.config, "agent_max_generations", 20)
-        generator_max_output_len = getattr(self.config, "generator_max_output_len", 1024)
+        agent_max_generations = getattr(self.config, "agent_max_generations", 32)
+        generator_max_output_len = getattr(self.config, "generator_max_output_len", 24576)
 
         try:
             logger.info(f"🤖 Running agent inference for {len(questions)} questions...")
@@ -422,7 +419,7 @@ class SGLRemoteGenerator(BaseGenerator):
         except requests.exceptions.RequestException as e:
             logger.error(f"Network error during generation: {str(e)}", exc_info=True)
             raise
-        except json.JSONDecodeError as e:
+        except json.JSONDecodeError:
             response_text = "Unknown (error occurred before response object assignment)"
             if "response" in locals() and hasattr(response, "text"):
                 response_text = response.text[:500]
@@ -561,8 +558,8 @@ def research(args: argparse.Namespace, config: Config):
         logger.error(f"Failed to initialize ReSearchPipeline: {e}", exc_info=True)
         return
 
-    agent_max_generations = getattr(config, "agent_max_generations", 20)
-    generator_max_output_len = getattr(config, "generator_max_output_len", 1024)
+    agent_max_generations = getattr(config, "agent_max_generations", 32)
+    generator_max_output_len = getattr(config, "generator_max_output_len", 24576)
 
     try:
         logger.info("🏃 Starting pipeline run...")
diff --git a/scripts/serving/serve_generator.py b/scripts/serving/serve_generator.py
index 6901929..571a553 100644
--- a/scripts/serving/serve_generator.py
+++ b/scripts/serving/serve_generator.py
@@ -56,6 +56,8 @@ def launch_sglang_server(
         host,
         "--port",
         str(port),
+        "--mem-fraction-static",
+        "0.5",
         "--trust-remote-code",
         # Recommended by SGLang for stability sometimes
         "--disable-overlap",
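
Note on the server change: --mem-fraction-static caps the fraction of GPU memory SGLang reserves up front for model weights and the KV-cache pool, so a value of 0.5 leaves roughly half of the device free for other processes (for example the retriever). As a minimal sketch, assuming the rest of launch_sglang_server builds a "python -m sglang.launch_server" command and starts it with subprocess (only the flags visible in the hunk above come from this patch; the function signature, default host/port, and model path are illustrative assumptions), the launch looks roughly like this:

    # Sketch only: launch an SGLang server with static allocation capped at 0.5.
    # Model path and defaults are assumptions; the two --mem-fraction-static
    # entries mirror the hunk above.
    import subprocess
    import sys

    def launch_sglang_server(model_path: str, host: str = "0.0.0.0", port: int = 30000):
        cmd = [
            sys.executable,
            "-m",
            "sglang.launch_server",
            "--model-path",
            model_path,
            "--host",
            host,
            "--port",
            str(port),
            "--mem-fraction-static",  # cap weights + KV-cache pool allocation
            "0.5",                    # use at most ~50% of GPU memory
            "--trust-remote-code",
            "--disable-overlap",
        ]
        # Start the server as a child process; callers can poll or terminate it.
        return subprocess.Popen(cmd)

    if __name__ == "__main__":
        server = launch_sglang_server("Qwen/Qwen2.5-7B-Instruct")
        server.wait()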
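
On the evaluation side, both changed call sites resolve the new limits with getattr fallbacks, so a value set on the config still wins over the new defaults of 32 agent generations and 24576 output tokens. A small self-contained sketch of that lookup pattern, with SimpleNamespace standing in for the real flashrag Config object:

    # Sketch of the getattr-with-default pattern used in run_eval.py; the real
    # scripts read these from a flashrag Config, SimpleNamespace is a stand-in.
    from types import SimpleNamespace

    config = SimpleNamespace()  # no overrides set

    agent_max_generations = getattr(config, "agent_max_generations", 32)
    generator_max_output_len = getattr(config, "generator_max_output_len", 24576)
    print(agent_max_generations, generator_max_output_len)  # 32 24576

    # An explicit config value (e.g. from the experiment config) overrides the default.
    config.generator_max_output_len = 4096
    print(getattr(config, "generator_max_output_len", 24576))  # 4096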