From 1a18cd7bfdffe9971714887012b7a9b4e7725e81 Mon Sep 17 00:00:00 2001
From: thinhlpg
Date: Fri, 4 Apr 2025 10:11:23 +0700
Subject: [PATCH] feat: update training and evaluation configurations
 (configurable agent generation)

Increase the default max_generations in agentic_generate from 10 to 20,
expose max_generations as a parameter of run_eval (threaded through to
agent.run_agent), and add explicit reward_weights to the GRPO training
config.
---
 src/evaluation.py | 4 ++--
 train_grpo.py     | 3 ++-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/evaluation.py b/src/evaluation.py
index e63711d..750b7aa 100644
--- a/src/evaluation.py
+++ b/src/evaluation.py
@@ -148,7 +148,7 @@ def check_student_answers(
     return results
 
 
-def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None):
+def run_eval(generate_fn, verify_fn, tokenizer, max_generations=20, output_file=None, debug_file=None):
     """
     Run evaluation on the test dataset and return results.
 
@@ -175,7 +175,7 @@ def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=Non
         adapter = R1DistilTokenizerAdapter()
     agent = Agent(adapter)
 
-    agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions)
+    agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions, max_generations)
     full_chat_states = agentic_outputs.full_chat_states
     final_responses = agentic_outputs.final_response_str
     rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])
diff --git a/train_grpo.py b/train_grpo.py
index 7c7ada7..182a8a3 100644
--- a/train_grpo.py
+++ b/train_grpo.py
@@ -76,6 +76,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
     bf16=is_bfloat16_supported(),
     fp16=not is_bfloat16_supported(),
     output_dir=OUTPUT_DIR,
+    reward_weights=[4.0, 2.0, 1.0, 1.0, 1.0, 1.0],
     # report_to="tensorboard",  # ❓ Doesn't it create billions of tensorboard files if report_to is set right here?
 )
 
@@ -84,7 +85,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
 def agentic_generate(
     prompts: list,
     generate_fn,
-    max_generations: int = 10,
+    max_generations: int = 20,
 ):
     # Create agent with appropriate adapter based on tokenizer
     tokenizer_name = tokenizer.name_or_path.lower()
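
Note on the new run_eval parameter: callers can now cap the agent loop
explicitly instead of relying on a limit hardcoded inside run_agent. A
minimal usage sketch, assuming generate_fn, verify_fn, and tokenizer are
already constructed by the surrounding setup code; the output paths are
placeholders, and "generation turns" is our reading of what max_generations
bounds:

    from src.evaluation import run_eval

    results = run_eval(
        generate_fn,                      # model generation callable forwarded to the agent
        verify_fn,                        # verification/reward function over full chat states
        tokenizer,
        max_generations=20,               # new parameter: presumed cap on agent generation turns
        output_file="eval_results.json",  # placeholder path
        debug_file="eval_debug.json",     # placeholder path
    )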
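Note on reward_weights: in TRL's GRPOConfig, which the Unsloth config class
wraps, reward_weights must contain one float per reward function passed to
the trainer, paired positionally (unspecified weights default to 1.0). The
trainer call sits outside this diff, so the sketch below is illustrative
only: the trainer class name is assumed from the config's naming, and the
six reward-function names are hypothetical stand-ins for the repo's own:

    trainer = UnslothGRPOTrainerTemp.UnslothGRPOTrainer(  # class name assumed, mirrors the config's naming
        model=model,
        args=training_args,        # carries reward_weights=[4.0, 2.0, 1.0, 1.0, 1.0, 1.0]
        reward_funcs=[
            reward_correctness,    # weighted 4.0 -- weights pair by list position
            reward_format,         # weighted 2.0
            reward_a,              # weighted 1.0 (hypothetical name)
            reward_b,              # weighted 1.0 (hypothetical name)
            reward_c,              # weighted 1.0 (hypothetical name)
            reward_d,              # weighted 1.0 (hypothetical name)
        ],
        train_dataset=train_dataset,
    )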