feat: update training and evaluation configurations (editable agent generation scripts)

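Raised the default max_generations in agentic_generate from 10 to 20, added a max_generations parameter (default 20) to run_eval that is forwarded to agent.run_agent, and set reward_weights in the GRPO training configuration, for improved output flexibility.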
main
thinhlpg 1 month ago
parent 77f121662f
commit 1a18cd7bfd

@ -148,7 +148,7 @@ def check_student_answers(
return results return results
def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=None): def run_eval(generate_fn, verify_fn, tokenizer, max_generations=20, output_file=None, debug_file=None):
""" """
Run evaluation on the test dataset and return results. Run evaluation on the test dataset and return results.
@ -175,7 +175,7 @@ def run_eval(generate_fn, verify_fn, tokenizer, output_file=None, debug_file=Non
adapter = R1DistilTokenizerAdapter() adapter = R1DistilTokenizerAdapter()
agent = Agent(adapter) agent = Agent(adapter)
agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions) agentic_outputs = agent.run_agent(generate_fn, tokenizer, questions, max_generations)
full_chat_states = agentic_outputs.full_chat_states full_chat_states = agentic_outputs.full_chat_states
final_responses = agentic_outputs.final_response_str final_responses = agentic_outputs.final_response_str
rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"]) rewards = verify_fn(questions, full_chat_states, answer=test_dataset["answer"])

@ -76,6 +76,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
bf16=is_bfloat16_supported(), bf16=is_bfloat16_supported(),
fp16=not is_bfloat16_supported(), fp16=not is_bfloat16_supported(),
output_dir=OUTPUT_DIR, output_dir=OUTPUT_DIR,
reward_weights=[4.0, 2.0, 1.0, 1.0, 1.0, 1.0],
# report_to="tensorboard", # ❓ Does't have billions of tensorboard files if set report to right here # report_to="tensorboard", # ❓ Does't have billions of tensorboard files if set report to right here
) )
@ -84,7 +85,7 @@ training_args = UnslothGRPOTrainerTemp.UnslothGRPOConfig(
def agentic_generate( def agentic_generate(
prompts: list, prompts: list,
generate_fn, generate_fn,
max_generations: int = 10, max_generations: int = 20,
): ):
# Create agent with appropriate adapter based on tokenizer # Create agent with appropriate adapter based on tokenizer
tokenizer_name = tokenizer.name_or_path.lower() tokenizer_name = tokenizer.name_or_path.lower()

Loading…
Cancel
Save