# Debug GRPO training for R1 distil

- I want to be able to continue finetuning the model from R1 distil checkpoints.
- The errors also occurred with the plain Qwen 2.5 1.5B Instruct model.
- The root cause is that the mask and the ids have different lengths, caused by custom mask logic written only for the Llama architecture.

## Debug strategy

The goal is to ensure that for every chat state `i`, the length of `response_toks[i]` is exactly the same as the length of `response_masks[i]` after all processing (slicing and truncation) within the final loop of `run_agent`.

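A minimal sketch of that check, assuming `response_toks` and `response_masks` are the per-state lists built inside `run_agent` (names taken from the description above; the exact variables may differ):

```python
# Sketch only: verify token/mask alignment for every chat state before the
# tensors are handed to the trainer, truncating both to the shorter length
# if they ever diverge so downstream padding and loss code see equal lengths.
for i in range(len(response_toks)):
    toks, mask = response_toks[i], response_masks[i]
    if len(toks) != len(mask):
        print(f"state {i}: len(toks)={len(toks)} != len(mask)={len(mask)}")
        n = min(len(toks), len(mask))
        response_toks[i], response_masks[i] = toks[:n], mask[:n]
    assert len(response_toks[i]) == len(response_masks[i])
```
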
## FOUND IT

```python
# Debug prints added inside the trainer: check per-sample lengths before padding.
for i in range(len(prompt_inputs)):
    print(f" prompt_inputs {i} len before padding: {len(prompt_inputs[i])}")
    print(f" completion_ids {i} len before padding: {len(completion_ids[i])}")
    print(f" completion_mask {i} len before padding: {len(completion_mask[i])}")

prompt_ids = pad(
    prompt_inputs,
    padding_value=self.processing_class.pad_token_id,
    padding_side="left",
).to(device)
completion_mask = pad(
    completion_mask,
    padding_value=0,
    padding_side="right",
).to(device)

# print length after padding
for i in range(len(prompt_inputs)):
    print(f" prompt_ids {i} len after padding: {len(prompt_ids[i])}")
    print(f" completion_ids {i} len after padding: {len(completion_ids[i])}")
    print(f" completion_mask {i} len after padding: {len(completion_mask[i])}")
```


- DeepSeek R1 (the pattern is mask length = ids length + 2, which then magically turns into 1025 after padding?):

```bash
prompt_inputs 0 len before padding: 214
completion_ids 0 len before padding: 99
completion_mask 0 len before padding: 101
prompt_inputs 1 len before padding: 214
completion_ids 1 len before padding: 312
completion_mask 1 len before padding: 314
prompt_inputs 2 len before padding: 214
completion_ids 2 len before padding: 296
completion_mask 2 len before padding: 298
prompt_inputs 3 len before padding: 214
completion_ids 3 len before padding: 270
completion_mask 3 len before padding: 272
prompt_inputs 4 len before padding: 214
completion_ids 4 len before padding: 1024
completion_mask 4 len before padding: 1025
prompt_inputs 5 len before padding: 214
completion_ids 5 len before padding: 71
completion_mask 5 len before padding: 72
prompt_inputs 6 len before padding: 214
completion_ids 6 len before padding: 76
completion_mask 6 len before padding: 78
prompt_inputs 7 len before padding: 214
completion_ids 7 len before padding: 1024
completion_mask 7 len before padding: 1025
prompt_ids 0 len after padding: 214
completion_ids 0 len after padding: 99
completion_mask 0 len after padding: 1025
prompt_ids 1 len after padding: 214
completion_ids 1 len after padding: 312
completion_mask 1 len after padding: 1025
prompt_ids 2 len after padding: 214
completion_ids 2 len after padding: 296
completion_mask 2 len after padding: 1025
prompt_ids 3 len after padding: 214
completion_ids 3 len after padding: 270
completion_mask 3 len after padding: 1025
prompt_ids 4 len after padding: 214
completion_ids 4 len after padding: 1024
completion_mask 4 len after padding: 1025
prompt_ids 5 len after padding: 214
completion_ids 5 len after padding: 71
completion_mask 5 len after padding: 1025
prompt_ids 6 len after padding: 214
completion_ids 6 len after padding: 76
completion_mask 6 len after padding: 1025
prompt_ids 7 len after padding: 214
completion_ids 7 len after padding: 1024
completion_mask 7 len after padding: 1025
```
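
The jump to 1025 is not magic on its own: `pad(..., padding_side="right")` pads every mask to the longest mask in the batch, and the longest `completion_mask` here is 1025 tokens (one longer than the 1024-token completions). A rough sketch of that behaviour, assuming `pad` works like trl's utility and right-pads each 1-D tensor to the batch maximum; the real bug is the pre-padding off-by-two above, not the padding itself:

```python
import torch

def pad_right(seqs, padding_value=0):
    # pad every 1-D tensor to the length of the longest one in the batch
    max_len = max(s.size(0) for s in seqs)
    return torch.stack(
        [torch.cat([s, s.new_full((max_len - s.size(0),), padding_value)]) for s in seqs]
    )

# mask lengths taken from the DeepSeek R1 log above
masks = [torch.ones(n, dtype=torch.long) for n in (101, 314, 298, 272, 1025, 72, 78, 1025)]
print(pad_right(masks).shape)  # torch.Size([8, 1025]) -> every mask ends up length 1025
```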

- and this is Llama (ids and masks already match before padding):

```bash
prompt_inputs 0 len before padding: 240
completion_ids 0 len before padding: 572
completion_mask 0 len before padding: 572
prompt_inputs 1 len before padding: 240
completion_ids 1 len before padding: 323
completion_mask 1 len before padding: 323
prompt_inputs 2 len before padding: 240
completion_ids 2 len before padding: 58
completion_mask 2 len before padding: 58
prompt_inputs 3 len before padding: 240
completion_ids 3 len before padding: 61
completion_mask 3 len before padding: 61
prompt_inputs 4 len before padding: 240
completion_ids 4 len before padding: 292
completion_mask 4 len before padding: 292
prompt_inputs 5 len before padding: 240
completion_ids 5 len before padding: 588
completion_mask 5 len before padding: 588
prompt_inputs 6 len before padding: 240
completion_ids 6 len before padding: 617
completion_mask 6 len before padding: 617
prompt_inputs 7 len before padding: 240
completion_ids 7 len before padding: 62
completion_mask 7 len before padding: 62
prompt_ids 0 len after padding: 240
completion_ids 0 len after padding: 572
completion_mask 0 len after padding: 617
prompt_ids 1 len after padding: 240
completion_ids 1 len after padding: 323
completion_mask 1 len after padding: 617
prompt_ids 2 len after padding: 240
completion_ids 2 len after padding: 58
completion_mask 2 len after padding: 617
prompt_ids 3 len after padding: 240
completion_ids 3 len after padding: 61
completion_mask 3 len after padding: 617
prompt_ids 4 len after padding: 240
completion_ids 4 len after padding: 292
completion_mask 4 len after padding: 617
prompt_ids 5 len after padding: 240
completion_ids 5 len after padding: 588
completion_mask 5 len after padding: 617
prompt_ids 6 len after padding: 240
completion_ids 6 len after padding: 617
completion_mask 6 len after padding: 617
prompt_ids 7 len after padding: 240
completion_ids 7 len after padding: 62
completion_mask 7 len after padding: 617
```

## Bug summary

```text
The immediate cause of the crash (TorchRuntimeError) was that the mask tensor had a different sequence length dimension (e.g., 574) than the loss_i tensor (e.g., 294) it was being multiplied with element-wise inside the loss calculation.

You can't multiply tensors of shape (B, SeqLen1) and (B, SeqLen2) element-wise if SeqLen1 is not equal to SeqLen2. The fix ensures both tensors have the same sequence length before the multiplication happens.
```

```text
What happened: the code crashed with a TorchRuntimeError indicating a shape mismatch during tensor multiplication (loss_i * mask) inside the grpo_compute_loss function, specifically when running under torch.compile.

The core issue: the completion_mask tensor (representing which completion tokens are valid) was being passed into the loss calculation with a sequence length (e.g., 574) that reflected the initial length of the generated sequence before final processing or slicing. However, the loss_i tensor (representing the per-token loss contribution) was correctly calculated based on the intended completion length (logits_to_keep, e.g., 294).
```
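
A hedged sketch of the kind of fix the summary describes (not the exact patch): make sure the mask covers the same `logits_to_keep` completion tokens that `loss_i` was computed over before the element-wise multiply in `grpo_compute_loss`.

```python
# Sketch only: loss_i has shape (B, logits_to_keep) while completion_mask may be
# longer (see the off-by-two above). Slice the mask to the same width before the
# masked mean; whether the extra positions should be dropped from the front or
# the back depends on how the mask was built upstream.
mask = completion_mask[:, :logits_to_keep].to(loss_i.dtype)
loss = (loss_i * mask).sum() / mask.sum()
```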

## The Error

```bash
Search results: []
2025-04-01 13:06:42 | DEBUG | src.rl_helpers_r1_distil:reward_exact_match_chunk_query:745 - Reward for prompt 7: 0.0
2025-04-01 13:06:42 | INFO | src.rl_helpers_r1_distil:reward_exact_match_chunk_query:781 - Chunk Query Rewards Summary:
2025-04-01 13:06:42 | INFO | src.rl_helpers_r1_distil:reward_exact_match_chunk_query:782 - Total prompts: 8
2025-04-01 13:06:42 | INFO | src.rl_helpers_r1_distil:reward_exact_match_chunk_query:783 - Correct matches: 2.0
2025-04-01 13:06:42 | INFO | src.rl_helpers_r1_distil:reward_exact_match_chunk_query:784 - Average reward: 0.250
2025-04-01 13:06:42 | INFO | src.rl_helpers_r1_distil:reward_exact_match_chunk_query:785 - Reward std: 0.433
rewards_per_func: tensor([0.6250, 0.4375, 0.9500, 0.2500], device='cuda:0')
2025-04-01 13:06:43 | CRITICAL | src.config:exception_handler:132 - Unhandled exception
Traceback (most recent call last):

> File "/home/thinhlpg/code/DeepSearch/train_grpo_r1_distil.py", line 125, in <module>
    trainer.train()
    │       └ <function Trainer.train at 0x7d71f573b560>
    └ <src.UnslothGRPOTrainerTemp.UnslothGRPOTrainer object at 0x7d71982cde10>

...

    raise error_type(message_evaluated)
          │          └ 'The size of tensor a (s4) must match the size of tensor b (s7) at non-singleton dimension 1)'
          └ <class 'RuntimeError'>

torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function mul>(*(GradTrackingTensor(lvl=1, value=
FakeTensor(..., device='cuda:0', size=(1, s4))
), GradTrackingTensor(lvl=1, value=
FakeTensor(..., device='cuda:0', size=(1, s7))
)), **{}):
The size of tensor a (s4) must match the size of tensor b (s7) at non-singleton dimension 1)

from user code:
  File "/home/thinhlpg/code/DeepSearch/src/UnslothGRPOTrainerTemp.py", line 186, in accumulate_chunk
    ) = torch.func.grad_and_value(
  File "/home/thinhlpg/miniconda3/envs/deepsearch-py311/lib/python3.11/site-packages/torch/_functorch/apis.py", line 442, in wrapper
    return eager_transforms.grad_and_value_impl(
  File "/home/thinhlpg/miniconda3/envs/deepsearch-py311/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 48, in fn
    return f(*args, **kwargs)
  File "/home/thinhlpg/miniconda3/envs/deepsearch-py311/lib/python3.11/site-packages/torch/_functorch/eager_transforms.py", line 1407, in grad_and_value_impl
    output = func(*args, **kwargs)
  File "/home/thinhlpg/code/DeepSearch/src/UnslothGRPOTrainerTemp.py", line 143, in compute_loss
    loss, completion_length, mean_kl = grpo_compute_loss(
  File "/home/thinhlpg/code/DeepSearch/src/UnslothGRPOTrainerTemp.py", line 112, in grpo_compute_loss
    loss = (loss_i * mask).sum() / mask.sum()
```