From bec864038ba4efb6a6dd771bb6604c9543cb46ac Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Mon, 14 Apr 2025 09:09:01 +0000
Subject: [PATCH] feat: increase max tokens and new tokens in evaluation
 scripts

---
 scripts/eval_base.py | 1 +
 scripts/eval_lora.py | 7 ++++---
 src/evaluation.py    | 6 ++++--
 3 files changed, 9 insertions(+), 5 deletions(-)

diff --git a/scripts/eval_base.py b/scripts/eval_base.py
index da1a1c0..87e0113 100644
--- a/scripts/eval_base.py
+++ b/scripts/eval_base.py
@@ -113,6 +113,7 @@ def main():
         output_file=output_file,
         debug_file=debug_file,
         max_generations=32,
+        max_new_tokens=4096 * 6,
     )
 
     logger.info("✨ Evaluation completed!")
diff --git a/scripts/eval_lora.py b/scripts/eval_lora.py
index b96e0d5..bbc2e04 100644
--- a/scripts/eval_lora.py
+++ b/scripts/eval_lora.py
@@ -41,7 +41,7 @@ def main():
     # Setup model with LoRA support using config values
     model, tokenizer = FastLanguageModel.from_pretrained(
         model_name=args.model_name,
-        max_seq_length=4096 * 2,
+        max_seq_length=4096 * 6,
         load_in_4bit=True,
         fast_inference=True,
         max_lora_rank=lora_config["r"],  # Use rank from config
@@ -64,14 +64,14 @@ def main():
     sampling_params = SamplingParams(
         temperature=args.temperature,
         top_p=0.95,
-        max_tokens=4096 * 2,
+        max_tokens=4096 * 6,
     )
 
     # Setup verifier with lower temperature
     verifier_params = SamplingParams(
         temperature=0.1,  # Lower temperature for more consistent verification
         top_p=0.95,
-        max_tokens=4096,
+        max_tokens=4096 * 6,
     )
 
     def generate_fn(inputs):
@@ -136,6 +136,7 @@ def main():
         output_file=output_file,
         debug_file=debug_file,
         max_generations=32,
+        max_new_tokens=4096 * 6,
     )
 
     logger.info("✨ Evaluation completed!")
diff --git a/src/evaluation.py b/src/evaluation.py
index 7559cd3..cf3e503 100644
--- a/src/evaluation.py
+++ b/src/evaluation.py
@@ -148,7 +148,9 @@ def check_student_answers(
     return results
 
 
-def run_eval(generate_fn, verify_fn, tokenizer, max_generations=20, output_file=None, debug_file=None):
+def run_eval(
+    generate_fn, verify_fn, tokenizer, max_generations=32, max_new_tokens=4096 * 6, output_file=None, debug_file=None
+):
     """
     Run evaluation on the test dataset and return results.
 
@@ -179,7 +181,7 @@ def run_eval(generate_fn, verify_fn, tokenizer, max_generations=20, output_file=
     adapter = R1DistilTokenizerAdapter()
     agent = Agent(adapter)
 
-    agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions, max_generations)
+    agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions, max_generations, max_new_tokens)
     full_chat_states = agentic_outputs.full_chat_states
     final_responses = agentic_outputs.final_response_str
     rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
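
Note on the change: the new `max_new_tokens` argument flows from the two evaluation scripts through `run_eval` into `agent.run_agent`, so the model's `max_seq_length`, both `SamplingParams.max_tokens` budgets, and the agent's per-step generation limit now agree at a single `4096 * 6` (24576-token) value. Below is a minimal sketch of calling the updated `run_eval` signature; the stub `generate_fn`/`verify_fn` bodies, the `None` tokenizer stand-in, and the output file names are illustrative assumptions, not code from this patch (the real scripts build these around `FastLanguageModel` and `SamplingParams` as shown in the diff):

```python
# Hypothetical usage sketch for the post-patch run_eval signature.
# Assumes this repository's src/ package is importable.
from src.evaluation import run_eval


def generate_fn(inputs):
    """Placeholder generator; the real scripts wrap the fast-inference model."""
    return ["<answer>stub</answer>" for _ in inputs]


def verify_fn(questions, chat_states, answer=None):
    """Placeholder verifier; the real scripts score answers with a
    low-temperature verifier model."""
    return [0.0] * len(questions)


# In the real scripts, tokenizer comes from FastLanguageModel.from_pretrained(...);
# None stands in for it here.
tokenizer = None

results = run_eval(
    generate_fn,
    verify_fn,
    tokenizer,
    max_generations=32,        # matches the new default set by this patch
    max_new_tokens=4096 * 6,   # 24576-token budget introduced by this patch
    output_file="eval_results.txt",  # hypothetical paths
    debug_file="eval_debug.json",
)
```

Aligning all four limits on one value avoids the pre-patch mismatch, where the verifier was capped at 4096 tokens and the context window at 8192 while multi-turn agent rollouts could need far more room.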