Seeing - "Recomputed values for the following tensors have different metadata than during the forward pass."
Seeing the following with the llama4_17bx16e model.
```
rank11: File ".../lib/python3.10/site-packages/torch/utils/checkpoint.py", line 902, in check_recomputed_tensors_match
rank11:     raise CheckpointError(
rank11: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
rank11: tensor at position 46:
rank11: saved metadata: {'shape': torch.Size([965, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
rank11: recomputed metadata: {'shape': torch.Size([964, 5120]),
[rank13]: Traceback (most recent call last):
```
could you share the config to reproduce? I think the complaint is from activation checkpointing -- I saw it when exploring possible ways to register hooks for load balancing updates https://github.com/pytorch/torchtitan/pull/1114
The toml file follows.
```toml
[job]
dump_folder = "./outputs_llama4_17bx16e"
description = "Llama 4 Scout 17Bx16E training"

[profiling]
enable_profiling = false
save_traces_folder = "profile_trace"
profile_freq = 100

[metrics]
log_freq = 10
enable_tensorboard = false
save_tb_folder = "tb"

[model]
name = "llama4"
flavor = "17bx16e"
tokenizer_path = "./assets/tokenizer/original/tokenizer.model"

[optimizer]
name = "AdamW"
lr = 4e-3
eps = 1e-15

[lr_scheduler]
warmup_steps = 600
lr_min = 0.1

[training]
batch_size = 1
seq_len = 8192
max_norm = 1.0  # grad norm clipping
steps = 3000
compile = false
dataset = "c4"
dataset_path = "./data/hf/c4"

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
tensor_parallel_degree = 8
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 500
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = 'full'  # ['none', 'selective', 'full']

[float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = "output,router.gate"
```
Hmm I'm not sure if I can reproduce. Some more questions:
- what model configs are you using?
- are you using the Grouped MM or the for-loop implementation for MoE? This could depend on your hardware.
- are you running the load balancing I recently added?
@tianyu-l , please see the answers below.
- What model configs are you using? The TOML file is above, if that is what you are asking. Otherwise, the source is not changed.
- Are you using Grouped MM or the for-loop implementation for MoE? Not using Grouped MM.
- Are you running the load balancing I recently added? Just tried with the latest commit. Same issue.
Looks like the checkpoint recompute has an issue. This part looks suspicious:
```python
try:
    with _recomputation_hook(
        weakref.ref(frame), gid
    ), torch.autograd.enable_grad():
        frame.recompute_fn(*args)
except _StopRecomputationError:
    pass
frame.is_recomputed[gid] = True
frame.check_recomputed_tensors_match(gid)
```
One interesting observation is that if I make the following change, training proceeds, but hits NaN after some time.
```python
try:
    # print(f"frame.recompute_fn {frame.recompute_fn}")
    with _recomputation_hook(
        weakref.ref(frame), gid
    ), torch.autograd.enable_grad():
        frame.recompute_fn(*args)
        print(f" frame.recompute_fn(*args) {args}")
except _StopRecomputationError:
    print(f"_StopRecomputationError {len(args)} {args[0]} {args[1].shape} {args[2].shape} ")
    pass
```
args[0] is an empty dictionary, and args[1] and args[2] are tensors.
The checkpointing code looks complicated and is difficult to debug. It would be good to be able to relate the saved-vs-recomputed mismatch back to a model layer.
cc @soulitzer in case you have context
> It would be good to be able to relate the saved-vs-recomputed mismatch back to a model layer.
Would it be possible to pass `debug=True` to checkpoint? With this flag enabled, the error message will contain the list of operators that ran during forward/recompute and the captured stack traces.
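For example, something like this minimal sketch (plain `torch.utils.checkpoint` rather than the torchtitan wrapper; the module and tensor here are just stand-ins):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, set_checkpoint_debug_enabled

block = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# 1) per-call flag: on a mismatch, the CheckpointError message will include the
#    list of ops that ran during forward/recompute plus captured stack traces
out = checkpoint(block, x, use_reentrant=False, debug=True)
out.sum().backward()

# 2) process-wide switch, handy when the checkpoint() call is buried inside a wrapper
with set_checkpoint_debug_enabled(True):
    out = checkpoint(block, x, use_reentrant=False)
    out.sum().backward()
```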
@soulitzer you can access the log with debug=True here (llama4_17bx16e, N=2 PPN=8 TP=8 FSDP=2): https://github.com/ratnampa/misc_uploads/blob/main/torchtitan/llama4_17bx16e/N2_PPN8_TP8_FSDP2_llama4_17bx16e.log
I have also added log only for rank 1, might be easier to inspect. https://github.com/ratnampa/misc_uploads/blob/main/torchtitan/llama4_17bx16e/rank1.log
Thanks for the logs. It looks like there's a `split_with_sizes` that received slightly different inputs between the original forward and the recompute.
original:
```
torch._ops.aten.split_with_sizes.default($181, ['14', '99', '71', '38', '66', '105', '81', '78', '86', '43', '33', '118', '22', '35', '83', '52'])
```
recompute:
```
torch._ops.aten.split_with_sizes.default($245, ['14', '98', '71', '38', '66', '106', '81', '78', '86', '43', '33', '118', '22', '35', '83', '52'])
```
Any idea why that is the case?
Interesting pointer. Both the original and the recomputed splits above sum to 1024, but they differ at two locations (99 vs 98, and 105 vs 106). What does the position in the following refer to?
[rank1]: saved metadata: {'shape': torch.Size([99, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 49:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 51:
[rank1]: saved metadata: {'shape': torch.Size([99, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 52:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 53:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 55:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 84:
[rank1]: saved metadata: {'shape': torch.Size([105, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 85:
[rank1]: saved metadata: {'shape': torch.Size([105, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 87:
[rank1]: saved metadata: {'shape': torch.Size([105, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 88:
[rank1]: saved metadata: {'shape': torch.Size([105, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 89:
[rank1]: saved metadata: {'shape': torch.Size([105, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 91:
Position 91 means it is the 91st tensor saved in the checkpointed region.
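As a toy illustration of that counting (this uses `saved_tensors_hooks` directly rather than the checkpoint internals, so treat it only as a sketch): every tensor autograd packs for backward inside the region gets the next position index.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

positions = []

def pack(t):
    positions.append(tuple(t.shape))  # record shapes in the order autograd saves them
    return t

def unpack(t):
    return t

x = torch.randn(4, 8, requires_grad=True)
with saved_tensors_hooks(pack, unpack):
    y = torch.nn.functional.gelu(x @ x.t())  # mm and gelu each save tensors
y.sum().backward()

# "tensor at position i" in the checkpoint error corresponds to positions[i]
print(list(enumerate(positions)))
```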
How do I map the position to a layer in the model? Also, what is the code that decides the split?
Ah, I wouldn't look at the position number here. I'd just search for `split_with_sizes`; below that you'd find the Python and C++ stack traces, which should have the module information. In this case, what you're looking for should be /home/ratnampa/torchtitan/torchtitan/experiments/llama4/model/moe.py:40:forward
The way it's structured is basically:
op1
stack trace for op1
op2
stack trace for op2
...
@soulitzer, thanks. Added a PyTorch PR that adds layer identification to checkpoint discrepancies.
Thanks for the PR, I added a comment here https://github.com/pytorch/pytorch/pull/153021#pullrequestreview-2822708017.
A few questions:
- Are there any more design/RFC docs on activation checkpointing other than this?
- Is the AC metadata stored on the CPU? I assume the saved activations are left on the accelerators?
- I do see differences in the inputs to layers (e.g. x) between forward and recompute. Where could that come from? Could the RNG state play a role here?
- What is the best way to not do AC on specific layers?
- I see the selective_ac_option; what is an example of using the "op" option?
> Are there any more design/RFC docs on activation checkpointing other than this?
Not a lot. What type of information are you looking for? There's a code comment on some AC internals here https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L602.
> Is the AC metadata stored on the CPU? I assume the saved activations are left on the accelerators?
Yes
> I do see differences in the inputs to layers (e.g. x) between forward and recompute. Where could that come from? Could the RNG state play a role here?
It could come from your forward logic depending on global state, e.g. are you explicitly branching on any globals, or are there TorchDispatchModes/TorchFunctionModes active?
> What is the best way to not do AC on specific layers?
Don't think it's possible in TorchTitan through the config (@tianyu-l correct me if I'm wrong)
> I see the selective_ac_option; what is an example of using the "op" option?
It always saves some "compute-intensive ops", except that it only saves every other matmul. https://github.com/pytorch/torchtitan/blob/3b85aa31fffc46ecbf785a57ee314a01614f572f/torchtitan/models/llama3/parallelize_llama.py#L241
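Roughly, the op-level policy in that code looks like the following sketch (paraphrased from memory, so details may differ from the current torchtitan source): the policy function is handed every ATen op that runs in the region and decides save vs recompute.

```python
from collections import defaultdict

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

# ops considered compute-intensive enough to save (illustrative subset)
_save_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}

def _get_custom_policy(meta):
    def _custom_policy(ctx, func, *args, **kwargs):
        mode = "recompute" if ctx.is_recompute else "forward"
        mm_count_key = f"{mode}_mm_count"
        if func == torch.ops.aten.mm.default:
            meta[mm_count_key] += 1
        # save everything in _save_list, except only every other mm
        to_save = func in _save_list and not (
            func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
        )
        return CheckpointPolicy.MUST_SAVE if to_save else CheckpointPolicy.PREFER_RECOMPUTE
    return _custom_policy

def selective_checkpointing_context_fn():
    return create_selective_checkpoint_contexts(_get_custom_policy(defaultdict(int)))

# usage with a plain module (torchtitan passes the context_fn via ptd_checkpoint_wrapper)
block = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False, context_fn=selective_checkpointing_context_fn)
out.sum().backward()
```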
> What is the best way to not do AC on specific layers?
I believe you can always define your own apply_ac method
https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/parallelize_llama.py#L292
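For example, a minimal variant that skips AC for selected blocks could look like this (a sketch: the skip_layer_ids argument and the assumption that model.layers is a ModuleDict keyed by layer id are illustrative, not torchtitan's API):

```python
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)

def apply_ac_except(model, skip_layer_ids=frozenset({"0", "1"})):
    """Wrap every transformer block with full AC except the listed layer ids."""
    for layer_id, transformer_block in model.layers.named_children():
        if layer_id in skip_layer_ids:
            continue  # leave this block un-checkpointed
        wrapped = ptd_checkpoint_wrapper(transformer_block, preserve_rng_state=False)
        model.layers.register_module(layer_id, wrapped)
```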
@tianyu-l , @soulitzer
> I do see differences in the inputs to layers (e.g. x) between forward and recompute. Where could that come from? Could the RNG state play a role here?
>
> It could come from your forward logic depending on global state, e.g. are you explicitly branching on any globals, or are there TorchDispatchModes/TorchFunctionModes active?
Not sure. It is the Llama4 model out of the box.
@soulitzer I think this issue could likely be related to https://github.com/pytorch/torchtitan/issues/1323#issuecomment-3145324389 We should really save the routing results, whichever ops they come from.
@tianyu-l and @soulitzer, I was thinking about trying preserve_rng_state=True. The recompute difference shows up only for larger networks with MoE. It also appears that distributed is not necessary for this to show up.
It's not necessarily preserve_rng_state, but likely some numerical difference that gets magnified, followed by an op that produces data-dependent shapes. Unassigning because there's unlikely to be anything that can be done on the AC framework side. However, as tianyu points out, saving intermediate results may help.
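To make that mechanism concrete, here is a toy, self-contained sketch (not torchtitan code; the drift is exaggerated so the failure is deterministic) of a checkpointed region whose routing produces data-dependent shapes, so any forward/recompute divergence trips check_recomputed_tensors_match:

```python
import torch
from torch.utils.checkpoint import checkpoint

state = {"calls": 0}  # stands in for the tiny numerical drift between forward and recompute

def routed_block(x, w):
    # exaggerated perturbation on the second call (the recompute) so the demo always trips;
    # in the real model the drift is tiny, but argmax/topk magnifies it into a shape change
    drift = 0.0 if state["calls"] == 0 else 10.0
    state["calls"] += 1
    logits = x @ w + torch.tensor([drift, 0.0])
    expert = logits.argmax(dim=-1)                           # data-dependent routing decision
    counts = torch.bincount(expert, minlength=2).tolist()
    chunks = torch.split(x[torch.argsort(expert)], counts)   # split sizes depend on routing
    # exp() saves its output for backward, so the saved-tensor shapes are data-dependent
    return sum(c.exp().sum() for c in chunks)

x = torch.randn(64, 16, requires_grad=True)
w = torch.randn(16, 2)
out = checkpoint(routed_block, x, w, use_reentrant=False)
out.backward()  # raises CheckpointError: recomputed tensors have different metadata
```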
It is the MoE router that sprays the tokens to different experts that exposes this issue readily.
@githubsgi do you know whether it's
- the `router.gate` (https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/moe.py#L211), or
- the `sigmoid` (https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/moe.py#L215), or
- the `topk` (https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/moe.py#L225)

that gives this discrepancy between forward and backward recomputation?
@tianyu-l, it is hard to say from the debug log where the divergence started. The difference shows up in the `aten.split_with_sizes.default` output.
```
[rank4]: ['$182: bf16[63, 5120]', '$183: bf16[65, 5120]'] = torch._ops.aten.split_with_sizes.default($181, ['63', '65']) (140 of 177 in original: )
[rank4]: ['$246: bf16[62, 5120]', '$247: bf16[66, 5120]'] = torch._ops.aten.split_with_sizes.default($245, ['62', '66']) (204 of 285 in recompute: )
[rank4]: $154: bf16[128, 2] = torch._ops.aten.view.default($153, ['128', '2'])
[rank4]: $155: f32[128, 2] = torch._ops.aten._to_copy.default($154, dtype=torch.float32)
[rank4]: $156: f32[128, 2] = torch._ops.aten.sigmoid.default($155)
[rank4]: $158: f32[128, 2] = torch._ops.aten.add.Tensor($156, $157)
[rank4]: ('$159: f32[128, 1]', '$160: i64[128, 1]') = torch._ops.aten.topk.default($158, 1, 1)
[rank4]: $161: f32[128, 1] = torch._ops.aten.gather.default($156, 1, $160)
[rank4]: $162: i64[128] = torch._ops.aten.view.default($160, ['-1'])
[rank4]: $163: i64[2] = torch._ops.aten.histc.default($162, 2, 0, 2)
[rank4]: $164: i64[128] = torch._ops.aten.view.default($160, ['-1'])
[rank4]: ('$165: i64[128]', '$166: i64[128]') = torch._ops.aten.sort.stable($164, stable=True)
[rank4]: $167: f32[128] = torch._ops.aten.view.default($161, ['-1'])
[rank4]: $168: f32[128] = torch._ops.aten.index.Tensor($167, ['$166'])
[rank4]: $169: i64[128] = torch._ops.aten.floor_divide.default($166, 1)
[rank4]: $170: f32[2] = torch._ops.aten.add_.Tensor($170, $163)
[rank4]: $171: i64[128, 1] = torch._ops.aten.view.default($169, ['-1', '1'])
[rank4]: $172: i64[128, 5120] = torch._ops.aten.expand.default($171, ['-1', '5120'])
[rank4]: $173: bf16[128, 5120] = torch._ops.aten.view.default($146, ['-1', '5120'])
[rank4]: $174: bf16[128, 5120] = torch._ops.aten.gather.default($173, 0, $172)
[rank4]: $175: f32[128, 5120] = torch._ops.aten._to_copy.default($174, dtype=torch.float32)
[rank4]: $176: f32[128, 1] = torch._ops.aten.view.default($168, ['-1', '1'])
[rank4]: $177: f32[128, 5120] = torch._ops.aten.mul.Tensor($175, $176)
[rank4]: $178: bf16[128, 5120] = torch._ops.aten._to_copy.default($177, dtype=torch.bfloat16)
[rank4]: $179: i64[2] = torch._ops.aten._to_copy.default($163, dtype=torch.int64, layout=torch.strided, device=device(type='cpu'))
[rank4]: $180: bf16[128, 5120] = torch._ops.aten.view.default($178, ['128', '5120'])
[rank4]: ['$182: bf16[63, 5120]', '$183: bf16[65, 5120]'] = torch._ops.aten.split_with_sizes.default($181, ['63', '65'])
[rank4]: $208: bf16[128, 2] = torch._ops.aten.view.default($207, ['128', '2'])
[rank4]: $209: f32[128, 2] = torch._ops.aten._to_copy.default($208, dtype=torch.float32)
[rank4]: $210: f32[128, 2] = torch._ops.aten.sigmoid.default($209)
[rank4]: $211: f32[128, 2] = torch._ops.aten.detach.default($210)
[rank4]: $212: f32[128, 2] = torch._ops.aten.detach.default($211)
[rank4]: $214: f32[128, 2] = torch._ops.aten.add.Tensor($210, $213)
[rank4]: ('$215: f32[128, 1]', '$216: i64[128, 1]') = torch._ops.aten.topk.default($214, 1, 1)
[rank4]: $217: f32[128, 2] = torch._ops.aten.detach.default($210)
[rank4]: $218: f32[128, 2] = torch._ops.aten.detach.default($217)
[rank4]: $219: f32[128, 1] = torch._ops.aten.gather.default($210, 1, $216)
[rank4]: $220: i64[128] = torch._ops.aten.view.default($216, ['-1'])
[rank4]: $221: i64[2] = torch._ops.aten.histc.default($220, 2, 0, 2)
[rank4]: $222: i64[128] = torch._ops.aten.view.default($216, ['-1'])
[rank4]: ('$223: i64[128]', '$224: i64[128]') = torch._ops.aten.sort.stable($222, stable=True)
[rank4]: $225: f32[128] = torch._ops.aten.view.default($219, ['-1'])
[rank4]: $226: f32[128] = torch._ops.aten.index.Tensor($225, ['$224'])
[rank4]: $227: i64[128] = torch._ops.aten.floor_divide.default($224, 1)
[rank4]: $228: f32[2] = torch._ops.aten.add_.Tensor($228, $221)
[rank4]: $229: i64[128, 1] = torch._ops.aten.view.default($227, ['-1', '1'])
[rank4]: $230: i64[128, 5120] = torch._ops.aten.expand.default($229, ['-1', '5120'])
[rank4]: $231: bf16[128, 5120] = torch._ops.aten.view.default($196, ['-1', '5120'])
[rank4]: $232: bf16[128, 5120] = torch._ops.aten.detach.default($231)
[rank4]: $233: bf16[128, 5120] = torch._ops.aten.detach.default($232)
[rank4]: $234: bf16[128, 5120] = torch._ops.aten.gather.default($231, 0, $230)
[rank4]: $235: f32[128, 5120] = torch._ops.aten._to_copy.default($234, dtype=torch.float32)
[rank4]: $236: f32[128, 1] = torch._ops.aten.view.default($226, ['-1', '1'])
[rank4]: $237: f32[128, 1] = torch._ops.aten.detach.default($236)
[rank4]: $238: f32[128, 1] = torch._ops.aten.detach.default($237)
[rank4]: $239: f32[128, 5120] = torch._ops.aten.detach.default($235)
[rank4]: $240: f32[128, 5120] = torch._ops.aten.detach.default($239)
[rank4]: $241: f32[128, 5120] = torch._ops.aten.mul.Tensor($235, $236)
[rank4]: $242: bf16[128, 5120] = torch._ops.aten._to_copy.default($241, dtype=torch.bfloat16)
[rank4]: $243: i64[2] = torch._ops.aten._to_copy.default($221, dtype=torch.int64, layout=torch.strided, device=device(type='cpu'))
[rank4]: $244: bf16[128, 5120] = torch._ops.aten.view.default($242, ['128', '5120'])
[rank4]: ['$246: bf16[62, 5120]', '$247: bf16[66, 5120]'] = torch._ops.aten.split_with_sizes.default($245, ['62', '66'])
```
@githubsgi
I would guess if you checkpoint torch._ops.aten.topk.default or torch._ops.aten.sigmoid.default (by putting an op into the _save_list https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py#L237), or the router.gate mm (by following https://github.com/pytorch/torchtitan/pull/1589) the differences would be gone.
It would be nice if you could try and see which works.
Added the following to the save list. Needed to use full recompute though. Still see the recompute diff.
```python
torch.ops.aten.topk.default,
torch.ops.aten.sigmoid.default
```
@tianyu-l, @soulitzer, a couple of questions on the debug log.
- Is the number of operators in the forward and recompute expected to be the same? Seeing $460 vs $815. How should they line up?
[rank19]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank19]: saved metadata: {'shape': torch.Size([696, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
[rank19]: recomputed metadata: {'shape': torch.Size([704, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
[rank19]: saved metadata: {'shape': torch.Size([696, 1024]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
.
.
.
[rank19]: $448: bf16[1, 8192, 5120] = torch._ops.aten.view.default($447, ['1', '8192', '5120']) (409 of 415 in original)
[rank19]: $457: bf16[8, 1024, 5120] = torch._ops.aten.cat.default(['$449', '$450', '$451', '$452', '$453', '$454', '$455', '$456']) (411 of 415 in original)
[rank19]: $458: bf16[1, 1024, 5120] = torch._ops._c10d_functional.reduce_scatter_tensor.default($457, 'sum', 8, '11') (412 of 415 in original)
[rank19]: $458: bf16[1, 1024, 5120] = torch._ops._c10d_functional.wait_tensor.default($458) (413 of 415 in original)
[rank19]: $459: bf16[1, 1024, 5120] = torch._ops.aten.view.default($458, ['1', '1024', '5120']) (414 of 415 in original)
[rank19]: $460: bf16[1, 1024, 5120] = torch._ops.aten.add.Tensor($123, $459) (415 of 415 in original)
- Looks like recompute is occurring twice. Is that the right way to interpret the following?
[rank19]: $809: bf16[8192, 5120] = torch._ops.aten.detach.default($808) (772 of 776 in recompute)
[rank19]: $811: bf16[8192, 8192] = torch._ops.aten.detach.default($810) (773 of 776 in recompute)
[rank19]: $812: bf16[8192, 8192] = torch._ops.aten.detach.default($811) (774 of 776 in recompute)
[rank19]: $813: bf16[8192, 5120] = torch._ops.aten.mm.default($810, $807) (775 of 776 in recompute)
[rank19]: $815: bf16[8192, 5120] = torch._ops.aten.view.default($814, ['8192', '5120']) (776 of 776 in recompute)
.
.
.
[rank19]: $809: bf16[8192, 5120] = torch._ops.aten.detach.default($808)
[rank19]: $811: bf16[8192, 8192] = torch._ops.aten.detach.default($810)
[rank19]: $812: bf16[8192, 8192] = torch._ops.aten.detach.default($811)
[rank19]: $813: bf16[8192, 5120] = torch._ops.aten.mm.default($810, $807)
[rank19]: $815: bf16[8192, 5120] = torch._ops.aten.view.default($814, ['8192', '5120'])
> Is the number of operators in the forward and recompute expected to be the same?
Not exactly. For normal AC, only the number/order/properties of the tensors saved by those ops are expected to be the same. For SAC we DO have that requirement for the ops that are marked as saved.
> Looks like recompute is occurring twice. Is that the right way to interpret the following?
Doesn't seem like it. I think the logs are just repeated there.