Reward Functions
This module provides ready-made reward functions, primarily intended for use with GRPOTrainer and RLOOTrainer.
accuracy_reward
trl.rewards.accuracy_reward(completions: list, solution: list, log_extra: collections.abc.Callable[[str, list], None] | None = None, **kwargs)
Parameters
- completions (list[list[dict[str, str]]]): List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary containing the key "content" with the value being the text of the completion.
- solution (list[str]): List of the raw-text solutions to the questions/problems/prompts.
- log_extra (callable, optional): Callable to log extra columns to the completions table, provided automatically by the trainer. Defaults to None to allow calling the function directly outside of a trainer (e.g., for testing).
- **kwargs: Additional keyword arguments. This function does not use them, but they are required in the function signature to ensure compatibility with trainers like GRPOTrainer.
Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If gold is not parseable → return None to skip the example.
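To make this control flow concrete, here is a minimal sketch in plain Python. It is not TRL's implementation: it swaps math verification for an exact string comparison against the last \boxed{...} expression, and the helper names `extract_last_boxed` and `toy_accuracy_reward` are hypothetical. It also assumes that a gold solution counts as unparseable when it is empty, and that a completion with no \boxed{...} answer scores 0.0.

```python
from typing import Optional


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the content of the last \\boxed{...} in `text`, or None if absent."""
    start = text.rfind("\\boxed{")
    if start == -1:
        return None
    i = start + len("\\boxed{")
    depth = 1  # we are inside the opening brace of \boxed{
    for j in range(i, len(text)):
        if text[j] == "{":
            depth += 1
        elif text[j] == "}":
            depth -= 1
            if depth == 0:
                return text[i:j]
    return None  # unbalanced braces: treat as no answer


def toy_accuracy_reward(completions, solution, **kwargs) -> list[Optional[float]]:
    # Hypothetical simplification: exact string match instead of math verification.
    rewards: list[Optional[float]] = []
    for completion, gold in zip(completions, solution):
        gold = gold.replace(" ", "")
        if not gold:
            rewards.append(None)  # assumed "unparseable" gold: skip the example
            continue
        prediction = extract_last_boxed(completion[0]["content"])
        if prediction is None:
            rewards.append(0.0)  # assumed: no extractable answer scores 0.0
        else:
            rewards.append(1.0 if prediction.replace(" ", "") == gold else 0.0)
    return rewards
```

Returning None rather than 0.0 for an unusable gold label keeps a bad label from being scored as a wrong answer; the example is skipped instead.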
Example:
>>> from trl.rewards import accuracy_reward
>>> solutions = [r"\frac{1}{3}", r"\frac{1}{3}"]
>>> completions = [
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
... ]
>>> accuracy_reward(completions, solutions)
[1.0, 0.0]
reasoning_accuracy_reward
trl.rewards.reasoning_accuracy_reward(completions: list, solution: list, reasoning_delimiters: list[str] | None = None, log_extra: collections.abc.Callable[[str, list], None] | None = None, **kwargs)
Parameters
- completions (list[list[dict[str, str]]]): List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary containing the key "content" with the value being the text of the completion.
- solution (list[str]): List of the raw-text solutions to the questions/problems/prompts.
- reasoning_delimiters (list[str], optional): List of strings indicating where the reasoning content ends. The final answer is assumed to be after the last occurrence of any of these delimiters. If None, defaults to ["</think>"].
- log_extra (callable, optional): Callable to log extra columns to the completions table, provided automatically by the trainer. Defaults to None to allow calling the function directly outside of a trainer (e.g., for testing).
- **kwargs: Additional keyword arguments. This function does not use them, but they are required in the function signature to ensure compatibility with trainers like GRPOTrainer.
Reward function that removes the reasoning content and checks if the final answer matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If gold is not parseable → return None to skip the example.
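The reasoning-stripping step can be sketched as follows. `strip_reasoning` is a hypothetical helper, not part of trl.rewards: it keeps only the text after the last occurrence of any delimiter and returns None when no delimiter is present. Treating that None as "no final answer" is one way to obtain the 0.0 reward that the third completion in the example below receives; how TRL handles this case internally is an assumption here.

```python
from typing import Optional


def strip_reasoning(text: str, delimiters: Optional[list[str]] = None) -> Optional[str]:
    """Return the text after the last occurrence of any delimiter, or None if none occurs."""
    if delimiters is None:
        delimiters = ["</think>"]  # same default as reasoning_accuracy_reward
    last_start, last_delim = -1, ""
    for delim in delimiters:
        idx = text.rfind(delim)
        # Keep whichever delimiter occurs latest in the text.
        if idx > last_start:
            last_start, last_delim = idx, delim
    if last_start == -1:
        return None  # reasoning was never closed, so there is no final answer
    return text[last_start + len(last_delim):]
```

The remaining text is what gets compared against the ground truth; answers that appear only inside the reasoning block are ignored.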
Example:
>>> from trl.rewards import reasoning_accuracy_reward
>>> reasoning_delimiters = ["</think>"]
>>> solutions = [r"\frac{1}{3}", r"\frac{1}{3}", r"\frac{1}{3}"]
>>> completions = [
... [
... {
... "role": "assistant",
... "content": r"<think> Reasoning content </think> The final answer is \boxed{\frac{1}{3}}",
... }
... ],
... [
... {
... "role": "assistant",
... "content": r"<think> Reasoning content </think> The final answer is \boxed{\frac{1}{2}}",
... }
... ],
... [
... {
... "role": "assistant",
... "content": r"<think> Reasoning content with partial answers \boxed{\frac{1}{3}} but no final answer",
... }
... ],
... ]
>>> reasoning_accuracy_reward(completions, solutions, reasoning_delimiters=reasoning_delimiters)
[1.0, 0.0, 0.0]
think_format_reward
trl.rewards.think_format_reward(completions: list, **kwargs) → list[float]
Parameters
- completions (list[list[dict[str, str]]]): List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary containing the key "content" with the value being the text of the completion.
- **kwargs: Additional keyword arguments. This function does not use them, but they are required in the function signature to ensure compatibility with trainers like GRPOTrainer.
Returns
list[float]
A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0.
Reward function that checks if the reasoning process is enclosed within "<think>" and "</think>" tags. The
function returns a reward of 1.0 if the format is correct, otherwise 0.0.
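A minimal regex-based sketch of this check, assuming the completion must start with `<think>`, contain a single `<think>` block closed by `</think>`, and may have any text afterwards. The pattern and the name `sketch_think_format_reward` are illustrative assumptions; TRL's exact pattern may differ.

```python
import re

# Assumed format: starts with <think>, no second <think> opens later,
# the block is closed by </think>, and anything may follow.
# re.DOTALL lets "." match newlines inside multi-line reasoning.
THINK_PATTERN = re.compile(r"^<think>(?!.*<think>)(.*?)</think>.*$", re.DOTALL)


def sketch_think_format_reward(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [1.0 if THINK_PATTERN.match(content) else 0.0 for content in contents]
```

The negative lookahead rejects nested or repeated `<think>` tags, so a completion that opens a second reasoning block gets 0.0 even if it is eventually closed.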
get_repetition_penalty_reward
trl.rewards.get_repetition_penalty_reward(ngram_size: int = 3, max_penalty: float = -1.0) → Callable
Parameters
- ngram_size (int, optional, defaults to 3): Size of the token n-grams to consider.
- max_penalty (float, optional, defaults to -1.0): Most negative penalty, approached by a fully repetitive completion. Must be non-positive.
Returns
Callable
A reward function that takes a list of completion token ids and returns a list of penalties, each in [max_penalty, 0.0].
Reward function that penalizes repeated n-grams in a completion, used to discourage degenerate, repetitive text (a common failure mode and reward-hacking strategy when length- or format-shaping rewards are used). Reference: Appendix C.2 of the “Demystifying Long Chain-of-Thought Reasoning” paper (https://huggingface.co/papers/2502.03373).
The penalty is proportional to the fraction of repeated n-grams in the completion:

$$r = P \left(1 - \frac{\#\,\text{unique } n\text{-grams}}{\#\,n\text{-grams}}\right)$$

where $P$ is max_penalty. A completion with no repeated n-gram gets a reward of 0.0, while a fully repetitive one approaches max_penalty. The n-grams are computed over the completion token ids (the paper applies the penalty to repeated tokens), so the reward is tokenizer-defined and language-agnostic. Completions with fewer than ngram_size tokens get a reward of 0.0.
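The penalty described above can be sketched in plain Python as a factory that returns the reward function. This is an illustrative reimplementation, not TRL's source, and the name `sketch_repetition_penalty` is hypothetical; it counts overlapping n-grams over the token ids and scales the fraction of repeated ones by max_penalty.

```python
def sketch_repetition_penalty(ngram_size: int = 3, max_penalty: float = -1.0):
    if max_penalty > 0:
        raise ValueError(f"max_penalty must be non-positive, got {max_penalty}")

    def reward(completion_ids: list[list[int]], **kwargs) -> list[float]:
        rewards = []
        for ids in completion_ids:
            if len(ids) < ngram_size:
                rewards.append(0.0)  # too short to form a single n-gram
                continue
            # All overlapping n-grams of the token sequence.
            ngrams = [tuple(ids[i : i + ngram_size]) for i in range(len(ids) - ngram_size + 1)]
            repeated_fraction = 1 - len(set(ngrams)) / len(ngrams)
            # Avoid returning -0.0 when nothing repeats.
            rewards.append(max_penalty * repeated_fraction if repeated_fraction > 0 else 0.0)
        return rewards

    return reward
```

For example, with `penalty = sketch_repetition_penalty(ngram_size=2)`, the completion `[5, 5, 5, 5, 5]` has four bigrams but only one unique bigram, so `penalty([[5, 5, 5, 5, 5]])` returns `[-0.75]`.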
get_soft_overlong_punishment
trl.rewards.get_soft_overlong_punishment(max_completion_len: int, soft_punish_cache: int)
Returns a reward function that penalizes overlong completions without rewarding shorter ones. Reference: Eq. (13) of the DAPO paper (https://huggingface.co/papers/2503.14476).
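For reference, the piecewise length penalty of Eq. (13) in the DAPO paper is reproduced below, writing $L_{\max}$ for max_completion_len, $L_{\text{cache}}$ for soft_punish_cache, and $|y|$ for the number of completion tokens. The mapping of the paper's symbols to these parameters is an assumption based on the parameter names.

$$
R_{\text{length}}(y) =
\begin{cases}
0, & |y| \le L_{\max} - L_{\text{cache}} \\
\dfrac{(L_{\max} - L_{\text{cache}}) - |y|}{L_{\text{cache}}}, & L_{\max} - L_{\text{cache}} < |y| \le L_{\max} \\
-1, & L_{\max} < |y|
\end{cases}
$$

With max_completion_len=100 and soft_punish_cache=20, a 90-token completion falls in the middle case and gets (80 - 90) / 20 = -0.5, which matches the example below.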
Example:
from trl.rewards import get_soft_overlong_punishment
soft_overlong_punishment = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
completion_ids = [[1] * 90]  # one simulated 90-token completion; 90 lies between 100 - 20 = 80 and 100
rewards = soft_overlong_punishment(completion_ids)
print(rewards) # [-0.5]
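The same rule can be sketched in Python. `sketch_soft_overlong_punishment` is a hypothetical reimplementation, not TRL's source: it applies no penalty up to max_completion_len - soft_punish_cache tokens, a linear ramp down to -1.0 at max_completion_len, and -1.0 beyond it. The handling of completions longer than max_completion_len follows the DAPO paper rather than anything stated in this page.

```python
def sketch_soft_overlong_punishment(max_completion_len: int, soft_punish_cache: int):
    # Length at which the linear penalty starts (L_max - L_cache in DAPO's notation).
    threshold = max_completion_len - soft_punish_cache

    def reward(completion_ids: list[list[int]], **kwargs) -> list[float]:
        rewards = []
        for ids in completion_ids:
            length = len(ids)
            if length <= threshold:
                rewards.append(0.0)  # short enough: no penalty, and no bonus either
            elif length <= max_completion_len:
                # Linear ramp from 0.0 at the threshold to -1.0 at max_completion_len.
                rewards.append((threshold - length) / soft_punish_cache)
            else:
                rewards.append(-1.0)  # past the hard limit: full penalty
        return rewards

    return reward
```

Because the reward is never positive, shortening a completion below the threshold gains nothing, which is why the function discourages overlong outputs without rewarding brevity.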