Reward functions
This note is a collection of stolen reward functions and tips from other projects.
- NEED SOMETHING THAT MAKES THE MODEL WORK HARDER!!!
- Goal: design reward functions (Search Task!) for DeepSearch's GRPO training (likely exact match). (Try the suggestion by unsloth below, lol)
You can refer to the examples below. You can input your generations into an LLM like ChatGPT 4o or Llama 3.1 (8B) and design a reward function and verifier to evaluate it. For example, feed your generations into an LLM of your choice and set a rule: "If the answer sounds too robotic, deduct 3 points." This helps refine outputs based on quality criteria.
- Label Studio suggests consulting domain experts -> ask the LLM to act as a search engine expert??
- Starting from AutoDidact's defaults should be good enough, then figure out big-brain moves from there.
Implementation Phases
- 1. Just keep the default ones from AutoDidact and add the exact-match idea.
- Oh, they only use 2 reward functions: "reward_correctness" and "reward_formatting".
- 2. Add more if needed.
Pseudocode
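A minimal sketch of Phase 1's exact-match idea, assuming the same signature as the TRL-style reward functions in the unsloth snippet further down (a batch of chat-format `completions` plus an `answer` column). `normalize_answer`, `extract_answer`, and `exact_match_reward_func` are hypothetical names, not AutoDidact's code. The normalization follows the common SQuAD-style recipe (lowercase, drop punctuation and English articles, collapse whitespace) so trivial formatting differences don't zero out the reward.

```python
import re
import string


def normalize_answer(text: str) -> str:
    """SQuAD-style normalization: lowercase, strip punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def extract_answer(text: str) -> str:
    """Hypothetical helper: text between <answer> and </answer> (whole text if tags are missing)."""
    return text.split("<answer>")[-1].split("</answer>")[0].strip()


def exact_match_reward_func(completions, answer, **kwargs) -> list[float]:
    """Hypothetical Phase-1 reward: 1.0 if normalized prediction == normalized gold answer, else 0.0."""
    responses = [completion[0]["content"] for completion in completions]
    return [
        1.0 if normalize_answer(extract_answer(r)) == normalize_answer(a) else 0.0
        for r, a in zip(responses, answer)
    ]


completions = [
    [{"role": "assistant", "content": "<answer>The Eiffel Tower</answer>"}],
    [{"role": "assistant", "content": "<answer>Big Ben</answer>"}],
]
print(exact_match_reward_func(completions, answer=["eiffel tower", "eiffel tower"]))  # [1.0, 0.0]
```

Returning 1.0/0.0 keeps it a plain exact match; partial credit can be layered on later (Phase 2).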
Get a sense of Reward functions
https://github.com/kubernetes-bad/reward-composer
- Reward Composer is a collection of simple building blocks for making your perfect reward function for Reinforcement Learning training of language models... It's like Lego for GRPO.
https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb
- Really minimalist and simple GRPO training script (only 171 lines :O)
Example from unsloth's blog: https://docs.unsloth.ai/basics/reasoning-grpo-and-rl#reward-function-examples
You can reuse data across multiple epochs. - What does this mean 👀?
- Factual Accuracy: Checking whether the output contains verifiable facts.
- Logical Consistency: Ensuring that arguments or narratives are internally consistent, e.g., when solving propositional logic reasoning problems.
- Exact Match and Heuristics: Use deterministic rules to check correctness (e.g., exact match in math answers, passing test cases in code, matching the predefined categories or taxonomy etc.)
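The "deterministic rules" bullet can be sketched as two tiny checkers. Both are hypothetical helpers, and their parsing rules are assumptions: `numeric_match` treats math answers as equal when they parse to the same number (so "4" and "4.0" agree), and `category_match` only accepts labels from a predefined set.

```python
def numeric_match(prediction: str, gold: str, tol: float = 1e-6) -> bool:
    """Equal if both parse as numbers within tol (commas ignored); otherwise fall back to string equality."""
    try:
        return abs(float(prediction.replace(",", "")) - float(gold.replace(",", ""))) <= tol
    except ValueError:
        return prediction.strip() == gold.strip()


def category_match(prediction: str, allowed: set[str]) -> bool:
    """True only if the prediction is one of the predefined labels (case-insensitive)."""
    return prediction.strip().lower() in {c.lower() for c in allowed}


print(numeric_match("4.0", "4"))                               # True
print(numeric_match("five", "5"))                              # False
print(category_match(" Sports ", {"sports", "politics"}))      # True
```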
Designing a verifiable reward function requires expert knowledge, domain expertise, and structured data interfaces. Can I just have an LLM roleplay as a search engine expert? 👀
- Multi-Level Scoring: Implement tiered scoring mechanisms to reward partial correctness where applicable. (cool, might try this)
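A sketch of multi-level scoring for short search answers. The three tiers are assumptions chosen for illustration: exact match (1.0), gold answer contained in the response (0.5), and otherwise half the token-level F1 (at most 0.5). `tiered_reward` and `token_f1` are hypothetical helpers; the F1 is the usual bag-of-tokens overlap used in QA evaluation.

```python
from collections import Counter


def token_f1(prediction: str, gold: str) -> float:
    """Bag-of-tokens F1 between prediction and gold (0.0 if no overlap)."""
    pred_tokens, gold_tokens = prediction.lower().split(), gold.lower().split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def tiered_reward(prediction: str, gold: str) -> float:
    pred, ref = prediction.strip().lower(), gold.strip().lower()
    if pred == ref:
        return 1.0                     # tier 1: exact match
    if ref and ref in pred:
        return 0.5                     # tier 2: gold answer appears inside a longer response
    return 0.5 * token_f1(pred, ref)   # tier 3: partial credit, never above 0.5


print(tiered_reward("Paris", "paris"))             # 1.0
print(tiered_reward("New York City", "new york"))  # 0.5
print(tiered_reward("York", "new york"))           # about 0.33
```

Tier 2 is a raw substring check, so it can over-reward (a gold answer like "art" is "contained" in "party"); a token-level containment check would be stricter.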
- Validate the Reward Model Based on Generated Examples:
  - Run Controlled Tests: Generate model outputs and measure how well the reward function distinguishes correct from incorrect responses.
  - Evaluate for Robustness: Ensure the function avoids penalizing correct responses due to formatting issues or minor variations.
  - A/B Testing with RL Agents: Compare performance between models trained with and without the verifiable reward function.
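The "controlled tests" and "robustness" checks can be approximated with a small harness, assuming you have a handful of hand-labeled (response, gold answer, is_correct) triples. `evaluate_reward_fn` and `strict_match` are hypothetical helpers; the harness reports the mean reward for correct vs. incorrect responses and how often a thresholded reward agrees with the human label.

```python
from statistics import mean
from typing import Callable, Iterable, Tuple


def evaluate_reward_fn(
    reward_fn: Callable[[str, str], float],
    examples: Iterable[Tuple[str, str, bool]],
    threshold: float = 0.5,
) -> dict:
    """Score labeled (response, gold, is_correct) triples and report how well rewards separate them."""
    correct, incorrect, hits = [], [], 0
    for response, gold, is_correct in examples:
        score = reward_fn(response, gold)
        (correct if is_correct else incorrect).append(score)
        hits += (score >= threshold) == is_correct  # does the reward agree with the label?
    total = len(correct) + len(incorrect)
    return {
        "mean_correct": mean(correct) if correct else float("nan"),
        "mean_incorrect": mean(incorrect) if incorrect else float("nan"),
        "accuracy": hits / total if total else float("nan"),
    }


def strict_match(response: str, gold: str) -> float:
    return 1.0 if response == gold else 0.0


examples = [
    ("Paris", "Paris", True),
    ("paris.", "Paris", True),   # correct, but formatted differently
    ("London", "Paris", False),
]
report = evaluate_reward_fn(strict_match, examples)
print(report)
```

Here the strict matcher scores the correct but reformatted "paris." as 0.0, so agreement drops to 2/3: the kind of robustness failure the bullet above warns about.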
Reward Function vs Verifier
Stolen note from unsloth's docs:
Component | Purpose | Characteristics | Examples |
---|---|---|---|
Verifier | Determines correctness | No numerical scoring; binary correct/incorrect judgment | Checks if "2+2=5" is wrong; executes code to validate syntax/logic |
Reward Function | Assigns numerical scores | Converts verification to numbers; can include multiple criteria | Wrong answer: -1 or -2; correct answer: +1 or +2; penalties for length/readability |
Key Differences | Verifier: checks correctness without scoring. Reward Function: assigns scores without necessarily verifying. A Reward Function can use a Verifier, but they're distinct components. | | |
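To make the verifier/reward split concrete, here is a toy sketch for "a + b = c" statements. The function names, the +1/-1 scores, and the 20-character length limit are assumptions for illustration. The verifier only returns True/False; the reward function calls it and turns the verdict into a number, adding a length penalty, which matches the table's point that a reward function can use a verifier while staying a separate component.

```python
import re

ADDITION_RE = re.compile(r"\s*(-?\d+)\s*\+\s*(-?\d+)\s*=\s*(-?\d+)\s*")


def verify_addition(statement: str) -> bool:
    """Verifier: binary correct/incorrect judgment for 'a + b = c', no scoring."""
    m = ADDITION_RE.fullmatch(statement)
    if m is None:
        return False  # treat unparseable statements as incorrect
    a, b, c = map(int, m.groups())
    return a + b == c


def addition_reward(statement: str, max_chars: int = 20) -> float:
    """Reward function: converts the verifier's verdict to a number, plus a length penalty."""
    score = 1.0 if verify_addition(statement) else -1.0
    if len(statement) > max_chars:
        score -= 0.5  # readability/length penalty
    return score


print(verify_addition("2+2=5"))   # False
print(addition_reward("2+2=4"))   # 1.0
print(addition_reward("2+2=5"))   # -1.0
```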
Idea examples
Note taken from unsloth's docs.
Example #1: Simple Arithmetic Task
- Question: "2 + 2"
- Answer: "4"
- Reward Function 1:
- If a number is detected → +1
- If no number is detected → -1
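Example #1's reward could be written as below. This is a sketch: plain strings stand in for chat-message dicts for brevity, and the number regex is an assumption.

```python
import re

# Integers or decimals, optionally negative
NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")


def number_detected_reward(responses: list[str]) -> list[float]:
    """+1 if a number appears anywhere in the response, -1 otherwise."""
    return [1.0 if NUMBER_RE.search(r) else -1.0 for r in responses]


print(number_detected_reward(["The answer is 4", "four", "2 + 2 = 5"]))  # [1.0, -1.0, 1.0]
```

Note that "2 + 2 = 5" still earns +1: this reward only checks that a number is present, so on its own it is a format signal and needs to be paired with a correctness check.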
Example #2: Email Automation Task
- Question: Inbound email
- Answer: Outbound email
- Reward Functions:
- If the answer contains a required keyword → +1
- If the answer exactly matches the ideal response → +1
- If the response is too long → -1
- If the recipient's name is included → +1
- If a signature block (phone, email, address) is present → +1
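Example #2's five rules could be combined into a single scoring function like the sketch below. Everything here is hypothetical: the function name, word count as the length measure, the 150-word limit, and treating "a phone number and an email address are present" as a stand-in for a full signature block (postal addresses are not checked).

```python
import re

EMAIL_RE = re.compile(r"[\w.+-]+@[\w-]+\.[\w.]+")
PHONE_RE = re.compile(r"\+?\d[\d\s().-]{6,}\d")


def email_reward(response: str, ideal: str, recipient: str,
                 required_keywords: list[str], max_words: int = 150) -> float:
    """Hypothetical rule-based reward mirroring the five email rules above."""
    text = response.lower()
    score = 0.0
    if any(k.lower() in text for k in required_keywords):
        score += 1.0  # contains a required keyword
    if response.strip() == ideal.strip():
        score += 1.0  # exactly matches the ideal response
    if len(response.split()) > max_words:
        score -= 1.0  # too long (word count used as the length proxy)
    if recipient.lower() in text:
        score += 1.0  # recipient's name is included
    if EMAIL_RE.search(response) and PHONE_RE.search(response):
        score += 1.0  # crude signature check: phone number and email address present
    return score


reply = (
    "Hi Alice,\nThanks for your refund request. We will process it today.\n"
    "Best,\nBob\nPhone: +1 555 123 4567\nEmail: bob@example.com"
)
print(email_reward(reply, ideal="(some other ideal reply)", recipient="Alice",
                   required_keywords=["refund"]))  # 3.0
```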
Code Examples
- Below is a code snippet from @unslothai's sample notebook, which is taken from @willccbb's gist:
```python
import re

from trl import GRPOTrainer

# extract_xml_answer() is defined elsewhere in the notebook/gist (not shown here).


# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    q = prompts[0][-1]["content"]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # Debug print for the first sample of the batch
    print(
        "-" * 20,
        f"Question:\n{q}",
        f"\nAnswer:\n{answer[0]}",
        f"\nResponse:\n{responses[0]}",
        f"\nExtracted:\n{extracted_responses[0]}",
    )
    # 2.0 for an exact string match with the ground truth, else 0.0
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # 0.5 if the extracted answer consists only of digits
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]


def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # re.DOTALL added so .*? can span multi-line reasoning; the gist omits it
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]


def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r, re.DOTALL) for r in responses]  # re.DOTALL added, as above
    return [0.5 if match else 0.0 for match in matches]


def count_xml(text) -> float:
    # +0.125 per correctly placed tag; small per-character penalty for text after </answer>
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count


def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]


...
# model, tokenizer, training_args and dataset are set up earlier in the notebook (elided here)
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[  # Personal note: didn't expect this to be so simple to implement @@
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```
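One thing to watch in these format checks: the gist's version calls `re.match(pattern, r)` with no flags, and without `re.DOTALL` the `.` in `.*?` does not match newlines, so any multi-line reasoning fails both patterns. A quick self-contained check with a made-up completion:

```python
import re

strict = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
soft = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"

# Made-up completion whose reasoning spans two lines
multi_line = "<reasoning>\nstep 1\nstep 2\n</reasoning>\n<answer>\n4\n</answer>\n"

print(bool(re.match(strict, multi_line)))             # False: '.' stops at newlines
print(bool(re.match(strict, multi_line, re.DOTALL)))  # True
print(bool(re.match(soft, multi_line)))               # False
print(bool(re.match(soft, multi_line, re.DOTALL)))    # True
```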
Just curious, how did the team implement the reward functions for AlphaMaze?
Below is from AlphaMaze's repo:
We designed a reward function with 3 components:
- Correctness Reward (+0.2 per solution step): This reward is scaled according to the number of steps in the maze solution. Each valid movement step adds 0.2 points to the total score. For example, a solution requiring 4 steps earns a reward of 0.2 × 4 = 0.8 points, incentivizing both accuracy and efficiency in navigation.
- Integrity Reward (+0.5): This reward is given for each valid movement token (<|up|>, <|down|>, <|left|>, <|right|>) in the predicted sequence, encouraging the generation of meaningful and valid movement steps.
- Thinking Reward (+0.25): This reward is given for correctly using the `<think>` tag in the output, ensuring completeness and consistency in the reasoning format.

These reward components were weighted to prioritize correctness while also encouraging valid movement sequences and proper reasoning formatting with the `<think>` tag. We adapted the Group Relative Policy Optimization (GRPO) algorithm, as employed in DeepSeek-R1 [Guo et al., 2025], to perform reinforcement learning. GRPO estimates advantages based on relative group scores, offering computational efficiency compared to critic-based methods.
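The last sentence is the core of GRPO: each completion's reward is compared with the other completions sampled for the same prompt, instead of with a learned critic. A minimal sketch of that group-relative advantage, (r - mean) / (std + eps); the epsilon value and the use of the population standard deviation are assumptions, since implementations differ on both.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """GRPO-style advantage: standardize each reward against its own group (same prompt)."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; some implementations use the sample std instead
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled maze solutions for one prompt: two correct 4-step answers (0.2 x 4 = 0.8), two wrong
print(group_relative_advantages([0.8, 0.0, 0.8, 0.0]))  # roughly [1.0, -1.0, 1.0, -1.0]
print(group_relative_advantages([0.8, 0.8, 0.8, 0.8]))  # [0.0, 0.0, 0.0, 0.0]
```

The second call is relevant to "making the model work harder": if every completion in a group gets the same reward, all advantages are zero and that prompt contributes nothing to the policy-gradient term, so rewards that spread scores within a group (partial credit, tiers) tend to give more signal.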
```python
import logging
import re  # used by the commented-out format checks
from typing import List

logger = logging.getLogger(__name__)

# count_xml() and extract_xml_answer() are defined elsewhere in the repo (not shown in this excerpt).


def xmlcount_reward_func(completions, **kwargs) -> List[float]:
    """
    Reward function based on proper XML tag usage.

    Args:
        completions: Model completions

    Returns:
        List of reward scores
    """
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]


def int_reward_func(completions, **kwargs) -> List[float]:
    """
    Reward function that checks if responses contain valid direction tokens.

    Args:
        completions: Model completions

    Returns:
        List of reward scores
    """
    allowed_tokens = {"<|up|>", "<|down|>", "<|right|>", "<|left|>"}
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    # (The copied excerpt ends here; the return statement is missing.)


def correctness_reward_func(prompts, completions, answer, **kwargs) -> List[float]:
    """
    Reward function that checks correctness of answers.

    Args:
        prompts: Input prompts
        completions: Model completions
        answer: Ground truth answers

    Returns:
        List of reward scores
    """
    rewards = []
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    logger.debug('-' * 20)
    logger.debug(f"Question:\n{q}")
    logger.debug(f"\nAnswer:\n{answer[0]}")
    logger.debug(f"\nResponse:\n{responses[0]}")
    logger.debug(f"\nExtracted:\n{extracted_responses[0]}")
    for r, a in zip(extracted_responses, answer):
        if r == a:
            # "<|up|><|left|>".split("|><|") -> ["<|up", "left|>"], so len() == number of steps
            direction = r.split("|><|")
            rewards.append(len(direction) * 0.2)
        else:
            rewards.append(0.0)
    return rewards


# def strict_format_reward_func(completions, **kwargs) -> List[float]:
#     """
#     Reward function that checks if completions strictly follow the required format.
#     Args:
#         completions: Model completions
#     Returns:
#         List of reward scores
#     """
#     pattern = r"^<think>\n.*?\n</think>\n\n.*?\n$"
#     responses = [completion[0]["content"] for completion in completions]
#     matches = [re.match(pattern, r, re.DOTALL) for r in responses]
#     return [0.5 if match else 0.0 for match in matches]

# def soft_format_reward_func(completions, **kwargs) -> List[float]:
#     """
#     Reward function that checks if completions loosely follow the required format.
#     Args:
#         completions: Model completions
#     Returns:
#         List of reward scores
#     """
#     pattern = r"<think>.*?</think>\s*.*?"
#     responses = [completion[0]["content"] for completion in completions]
#     matches = [re.match(pattern, r, re.DOTALL) for r in responses]
#     return [0.5 if match else 0.0 for match in matches]

...
# Excerpt from the GRPOTrainer(...) arguments:
reward_funcs=[
    xmlcount_reward_func,
    # soft_format_reward_func,
    # strict_format_reward_func,
    int_reward_func,
    correctness_reward_func,
],
```
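Since the copied `int_reward_func` stops before its return statement, here is one hypothetical way the Integrity and Thinking rewards described above might be implemented. This is not AlphaMaze's code. The paper text is ambiguous about whether the +0.5 is awarded once or per valid token; this sketch assumes once per completion, and only when the answer consists entirely of allowed direction tokens. Functions take single strings for clarity and would need the usual batch wrapper (`completion[0]["content"]`, `extract_xml_answer`) to plug into `reward_funcs`.

```python
import re

ALLOWED_TOKENS = {"<|up|>", "<|down|>", "<|right|>", "<|left|>"}
TOKEN_RE = re.compile(r"<\|[a-z_]+\|>")  # any <|word|>-style token


def integrity_reward(answer_text: str) -> float:
    """Hypothetical: 0.5 if the answer is a non-empty run of allowed direction tokens and nothing else."""
    tokens = TOKEN_RE.findall(answer_text)
    leftover = TOKEN_RE.sub("", answer_text).strip()
    if tokens and not leftover and all(t in ALLOWED_TOKENS for t in tokens):
        return 0.5
    return 0.0


def thinking_reward(text: str) -> float:
    """Hypothetical: 0.25 if there is exactly one <think> ... </think> pair, in that order."""
    if text.count("<think>") == 1 and text.count("</think>") == 1:
        if text.index("<think>") < text.index("</think>"):
            return 0.25
    return 0.0


print(integrity_reward("<|up|><|up|><|left|>"))      # 0.5
print(integrity_reward("<|up|><|jump|>"))            # 0.0 (unknown token)
print(thinking_reward("<think>plan</think><|up|>"))  # 0.25
```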
Comparison of Alphamaze's reward functions and unsloth's
Feature | Unsloth Example | AlphaMaze | Similarities | Differences |
---|---|---|---|---|
Overall Purpose | To evaluate and score the quality of model-generated text based on various criteria (format, correctness, content). | Same as Unsloth. | Both aim to provide numerical rewards for model outputs based on defined criteria. | AlphaMaze appears more focused on a specific maze-solving task (directions in the answer), while Unsloth's examples are more general, including checking whether a predicted answer consists only of digits. |
Function Structure | Functions take `completions` (and sometimes `prompts`, `answer`) as input and return a list of floats (rewards). | Same as Unsloth. | Both use functions that take model outputs (and sometimes inputs) and return lists of reward scores. | AlphaMaze's `correctness_reward_func` scales the reward with the length of the correct answer (number of directions), while Unsloth's gives a fixed 2.0 for a correct answer and 0.0 otherwise. |
Reward Types | `correctness_reward_func`: checks if the extracted answer matches the ground truth; binary reward (2.0 or 0.0). `int_reward_func`: checks if the extracted answer `isdigit()`; binary reward (0.5 or 0.0). `strict_format_reward_func`, `soft_format_reward_func`: check for specific XML-like formatting using regular expressions; binary reward (0.5 or 0.0). `xmlcount_reward_func`: counts XML tags, giving a fractional reward for tag presence and penalizing trailing text. | `correctness_reward_func`: checks if the extracted answer matches the ground truth; reward is proportional to answer length (0.2 per direction). `int_reward_func`: checks if the answer consists of allowed tokens; the copied snippet is incomplete. `xmlcount_reward_func`: same wrapper as Unsloth's. `strict_format_reward_func` (commented out): checks for a specific format using regex. `soft_format_reward_func` (commented out): checks for a looser format. | Both have `correctness_reward_func`, `int_reward_func`, and `xmlcount_reward_func` (implemented slightly differently). Both use regular expressions for format checking. | Unsloth uses a binary reward for correctness; AlphaMaze uses a length-based reward. Unsloth's `int_reward_func` checks whether the answer is all digits; AlphaMaze's is meant to check for allowed direction tokens (the copied snippet is incomplete). AlphaMaze's formatting functions are commented out. |
`correctness_reward_func` | Compares the extracted answer to the ground truth and prints debugging information. Returns 2.0 for correct, 0.0 otherwise. | Compares the extracted answer to the ground truth; the reward is 0.2 per direction step in the correct answer. Logs debugging information. | Both compare the extracted answer to the ground truth and print/log debugging information. | Unsloth returns a fixed reward (2.0) for a correct answer; AlphaMaze's reward is proportional to the length of the correct answer (0.2 per direction). |
`int_reward_func` | Checks if the extracted response `isdigit()`. Returns 0.5 if true, 0.0 otherwise. | Intended to check if the response contains allowed direction tokens (`<\|up\|>`, `<\|down\|>`, `<\|left\|>`, `<\|right\|>`); the copied snippet has no return statement. | Both check the form of the extracted answer rather than its correctness. | Digit check vs. direction-token check. |
`xmlcount_reward_func` | Counts opening/closing tags and penalizes extra text (via `count_xml`). | Same wrapper; AlphaMaze's `count_xml` is not shown. | Identical wrapper function. | None visible in the snippets. |
Format Checking | Uses `strict_format_reward_func` and `soft_format_reward_func` with different regular expressions. | Has `strict_format_reward_func` and `soft_format_reward_func` (commented out) with different regular expressions. | Both use regular expressions to check for specific formatting patterns. | Unsloth's checks look for `<reasoning>` and `<answer>` tags; AlphaMaze's (commented out) look for `<think>` tags and a general structure. Unsloth's are active; AlphaMaze's are commented out. |
Extracted Answer | Uses an `extract_xml_answer` function (not shown in the provided snippets, assumed to be defined elsewhere). | Same as Unsloth. | Both rely on an external function to extract the relevant part of the response for evaluation. | The exact implementation of `extract_xml_answer` isn't shown, so there may be subtle differences; the usage is the same. |
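For reference, a helper like `extract_xml_answer` can be as small as the sketch below. Its behavior here is an assumption (take the text after the last `<answer>` up to the next `</answer>`, falling back to the whole string when the tags are missing); the versions in the unsloth notebook and the AlphaMaze repo may differ.

```python
def extract_xml_answer(text: str) -> str:
    """Sketch: text after the last <answer> and before the following </answer>, stripped."""
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()


print(extract_xml_answer("<reasoning>\n2 + 2\n</reasoning>\n<answer>\n4\n</answer>\n"))  # 4
```

The fallback matters for reward design: a completion with no tags is scored on its full text, so a correctness reward can still fire even when the format rewards give 0.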