torchtitan

Seeing - "Recomputed values for the following tensors have different metadata than during the forward pass."

Open githubsgi opened this issue 8 months ago • 19 comments

Seeing the following with the llama4_17bx16e model.

rank11: File ".../lib/python3.10/site-packages/torch/utils/checkpoint.py", line 902, in check_recomputed_tensors_match
rank11:     raise CheckpointError(
rank11: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
rank11: tensor at position 46:
rank11: saved metadata: {'shape': torch.Size([965, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
rank11: recomputed metadata: {'shape': torch.Size([964, 5120]),
[rank13]: Traceback (most recent call last):

githubsgi, Apr 17 '25 22:04

Could you share the config to reproduce? I think the complaint comes from activation checkpointing; I saw it when exploring possible ways to register hooks for load-balancing updates: https://github.com/pytorch/torchtitan/pull/1114

tianyu-l, Apr 17 '25 23:04

The toml file follows.

    [job]
    dump_folder = "./outputs_llama4_17bx16e"
    description = "Llama 4 Scout 17Bx16E training"

    [profiling]
    enable_profiling = false
    save_traces_folder = "profile_trace"
    profile_freq = 100

    [metrics]
    log_freq = 10
    enable_tensorboard = false
    save_tb_folder = "tb"

    [model]
    name = "llama4"
    flavor = "17bx16e"
    tokenizer_path = "./assets/tokenizer/original/tokenizer.model"

    [optimizer]
    name = "AdamW"
    lr = 4e-3
    eps = 1e-15

    [lr_scheduler]
    warmup_steps = 600
    lr_min = 0.1

    [training]
    batch_size = 1
    seq_len = 8192
    max_norm = 1.0  # grad norm clipping
    steps = 3000
    compile = false
    dataset = "c4"
    dataset_path = "./data/hf/c4"

    [parallelism]
    data_parallel_replicate_degree = 1
    data_parallel_shard_degree = -1
    tensor_parallel_degree = 8
    enable_async_tensor_parallel = false
    pipeline_parallel_degree = 1
    context_parallel_degree = 1

    [checkpoint]
    enable_checkpoint = false
    folder = "checkpoint"
    interval = 500
    model_weights_only = false
    export_dtype = "float32"
    async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]

    [activation_checkpoint]
    mode = 'full'  # ['none', 'selective', 'full']

    [float8]
    enable_fsdp_float8_all_gather = false
    precompute_float8_dynamic_scale_for_fsdp = false
    filter_fqns = "output,router.gate"

githubsgi, Apr 18 '25 00:04

Hmm I'm not sure if I can reproduce. Some more questions:

  • what model configs are you using?
  • are you using Grouped MM or the for-loop implementation for MoE? (This could depend on your hardware.)
  • are you running the load balancing I recently added?

tianyu-l, Apr 18 '25 02:04

@tianyu-l , please see the answers below.

  • What model configs are you using? The TOML file is above, if that is what you are asking. Otherwise, the source is unchanged.
  • Are you using Grouped MM or the for-loop implementation for MoE? Not using GroupedMM.
  • Are you running the load balancing I recently added? I just tried with the latest commit; same issue.

githubsgi, Apr 18 '25 23:04

It looks like checkpoint recompute has an issue; this looks funny:

                try:
                    with _recomputation_hook(
                        weakref.ref(frame), gid
                    ), torch.autograd.enable_grad():
                        frame.recompute_fn(*args)
                except _StopRecomputationError:
                    pass
                frame.is_recomputed[gid] = True
                frame.check_recomputed_tensors_match(gid)

One interesting observation: if I make the following change, training proceeds, but hits NaN after some time.

                try:
                    #print (f"frame.recompute_fn {frame.recompute_fn}")
                    with _recomputation_hook(
                        weakref.ref(frame), gid
                    ), torch.autograd.enable_grad():
                        frame.recompute_fn(*args)
                        print (f" frame.recompute_fn(*args) {args}")
                except _StopRecomputationError :
                    print (f"_StopRecomputationError  {len(args)} {args[0]} {args[1].shape} {args[2].shape} ")
                    pass

args[0] is an empty dictionary, and args[1] and args[2] are tensors.

The checkpointing code looks complicated and difficult to debug. It would be good to be able to relate a saved-vs-recomputed mismatch to the model layer it comes from.

githubsgi, Apr 25 '25 02:04

cc @soulitzer in case you have context

tianyu-l, Apr 27 '25 15:04

It will be good to be able to relate the model layer to the saved-recompute mismatch.

Would it be possible to pass debug=True to checkpoint? With this flag enabled, the error message would contain the list of operators that ran during the forward and the recompute, along with captured stack traces.

soulitzer, Apr 28 '25 16:04

@soulitzer, you can access the log with debug=True here (llama4_17bx16e, N=2, PPN=8, TP=8, FSDP=2): https://github.com/ratnampa/misc_uploads/blob/main/torchtitan/llama4_17bx16e/N2_PPN8_TP8_FSDP2_llama4_17bx16e.log

I have also added a log for rank 1 only, which might be easier to inspect: https://github.com/ratnampa/misc_uploads/blob/main/torchtitan/llama4_17bx16e/rank1.log

ratnampa, Apr 30 '25 17:04

Thanks for the logs. It looks like there's a split_with_sizes call that received slightly different inputs in the original forward and in the recompute.

original:

torch._ops.aten.split_with_sizes.default($181, ['14', '99', '71', '38', '66', '105', '81', '78', '86', '43', '33', '118', '22', '35', '83', '52'])   

recompute:

torch._ops.aten.split_with_sizes.default($245, ['14', '98', '71', '38', '66', '106', '81', '78', '86', '43', '33', '118', '22', '35', '83', '52'])

Any idea why that is the case?

soulitzer, Apr 30 '25 17:04

Interesting pointer. Both the original and the recompute size lists above sum to 1024, but they differ in two split locations (99 vs 98 and 105 vs 106). What does the position in the following refer to?

[rank1]: saved metadata: {'shape': torch.Size([99, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 49:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 51:
[rank1]: saved metadata: {'shape': torch.Size([99, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 52:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 53:
[rank1]: saved metadata: {'shape': torch.Size([99ug=True` to `torch.utils.checkpoint.checkpoint()`.
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 55:
[rank1]: saved metadata: {'shape': torch.Size([99, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([98, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 84:
[rank1]: saved metadata: {'shape': torch.Size([105, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 85:
[rank1]: saved metadata: {'shape': torch.Size([105, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 87:
[rank1]: saved metadata: {'shape': torch.Size([105, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 88:
[rank1]: saved metadata: {'shape': torch.Size([105, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 89:
[rank1]: saved metadata: {'shape': torch.Size([105, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: recomputed metadata: {'shape': torch.Size([106, 8192]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=1)}
[rank1]: tensor at position 91:
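The observation that both size lists sum to 1024 and differ in exactly two places is easy to check mechanically. Below is a minimal sketch of a hypothetical helper (not part of PyTorch) that parses the sizes out of split_with_sizes debug-log lines like the two quoted earlier. Its regex assumes the exact line format shown in this thread.

```python
import re

# Matches e.g. "split_with_sizes.default($181, ['14', '99', ...])"
SPLIT_RE = re.compile(r"split_with_sizes\.default\(\$\d+, \[([^\]]*)\]\)")


def split_sizes(log_line: str) -> list:
    """Extract the integer split sizes from one split_with_sizes log line."""
    match = SPLIT_RE.search(log_line)
    if match is None:
        raise ValueError("no split_with_sizes call in line")
    return [int(tok.strip(" '")) for tok in match.group(1).split(",")]


def compare_splits(original: str, recompute: str) -> dict:
    """Compare the split sizes of an original and a recomputed log line."""
    a, b = split_sizes(original), split_sizes(recompute)
    return {
        "totals": (sum(a), sum(b)),
        # (split index, original size, recomputed size) wherever they differ
        "diffs": [(i, x, y) for i, (x, y) in enumerate(zip(a, b)) if x != y],
    }
```

For the two lines above, `compare_splits` should give totals of (1024, 1024) and differences at split indices 1 (99 vs 98) and 5 (105 vs 106). With 16 entries, each index presumably corresponds to one of the 16 experts.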

githubsgi, Apr 30 '25 20:04

position 91 means it is the tensor at index 91 (counting from zero) among the tensors saved in the checkpointed region
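If it helps to see what a position refers to, the tensors that autograd saves in a region can be listed in save order with `torch.autograd.graph.saved_tensors_hooks`. This sketch assumes, without checking PyTorch internals here, that this pack order matches the order checkpoint uses for positions. `list_saved_tensors` is a hypothetical helper.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks


def list_saved_tensors(fn, *args):
    """Run fn once and record (position, shape, dtype) for every tensor that
    autograd saves for backward, in the order the pack hook sees them."""
    saved = []

    def pack(t: torch.Tensor) -> torch.Tensor:
        saved.append((len(saved), tuple(t.shape), t.dtype))
        return t  # keep the tensor itself; only the bookkeeping is needed

    def unpack(t: torch.Tensor) -> torch.Tensor:
        return t

    with saved_tensors_hooks(pack, unpack):
        fn(*args)
    return saved
```

For `torch.mm(x, w).sin().sum()`, it should list three tensors: the two `mm` inputs, then the `mm` output that `sin` saves.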

soulitzer, Apr 30 '25 20:04

How do I map the position to a layer in the model? Also, what is the code that decides the split?

githubsgi, Apr 30 '25 20:04

Ah, I wouldn't look at the position number here. I'd just search for split_with_sizes; below it you'd find the Python and C++ stack traces, which should have the module information. In this case, what you're looking for should be /home/ratnampa/torchtitan/torchtitan/experiments/llama4/model/moe.py:40:forward

The way it's structured is basically:

op1
stack trace for op1
op2
stack trace for op2
...

soulitzer, Apr 30 '25 22:04

@soulitzer, thanks. I opened a PyTorch PR that adds layer identification to checkpoint discrepancy errors.

githubsgi, May 07 '25 02:05

Thanks for the PR, I added a comment here https://github.com/pytorch/pytorch/pull/153021#pullrequestreview-2822708017.

soulitzer, May 07 '25 17:05

A few questions:

  1. Are there any more design/RFC docs on activation checkpointing other than this?
  2. Is the AC metadata stored on the CPU? I assume the saved activations are left on the accelerators?
  3. I do see differences in the inputs to layers (e.g. x) between forward and recompute. Where could that come from? Could the RNG state play a role here?
  4. What is the best way to skip AC on specific layers?
  5. I see the selective_ac_option; what is an example of using the "op" option?

githubsgi, May 15 '25 01:05

Is there any more design/rfc docs on activation checkpointing other than

Not a lot. What type of information are you looking for? There's a code comment on some AC internals here https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L602.

The ac metadata is stored in CPU? I guess the saved activations are left in the accelerators?

Yes

I do see differences in the input to layers (e.g. x) between forward and recompute. Where could that come from? Could the RNG state play a role here?

Possibly your forward logic depending on global state: e.g. are you explicitly branching on any globals, or are there active TorchDispatchModes/TorchFunctionModes?

What is the best way to not do ac on specific layers ?

I don't think it's possible through the TorchTitan config (@tianyu-l, correct me if I'm wrong).

I see the selective_ac_option, what is an example of using the "op" option ?

It always saves the outputs of a list of "compute-intensive ops", except for every other matmul, which is recomputed: https://github.com/pytorch/torchtitan/blob/3b85aa31fffc46ecbf785a57ee314a01614f572f/torchtitan/models/llama3/parallelize_llama.py#L241

soulitzer, May 15 '25 03:05

What is the best way to not do ac on specific layers ?

I believe you can always define your own apply_ac method https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/parallelize_llama.py#L292
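Here is a minimal sketch of an `apply_ac`-style function that skips chosen layers. It assumes the blocks live in an `nn.ModuleDict` keyed by layer id, as in torchtitan's Llama models. It uses PyTorch's `checkpoint_wrapper`, which lives in a private module and may move between releases. `apply_ac_except` and `skip` are hypothetical names.

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
    checkpoint_wrapper,
)


def apply_ac_except(layers: nn.ModuleDict, skip: set) -> None:
    """Wrap every block in `layers` with full activation checkpointing,
    leaving the blocks whose ids are in `skip` unwrapped."""
    for layer_id, block in list(layers.named_children()):
        if layer_id in skip:
            continue  # this layer keeps all of its activations
        # Replace the block in place with a non-reentrant checkpointed wrapper.
        layers.register_module(layer_id, checkpoint_wrapper(block))
```

For example, `apply_ac_except(model.layers, {"0", "1"})` would leave the first two blocks without AC and checkpoint the rest.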

tianyu-l, May 15 '25 07:05

@tianyu-l , @soulitzer

I do see differences in the input to layers (e.g. x) between forward and recompute. Where could that come from? Could the RNG state play a role here?

Possibly your forward logic depending on global state: e.g. are you explicitly branching on any globals, or are there active TorchDispatchModes/TorchFunctionModes?

Not sure. It is the Llama4 model out of the box.

githubsgi, May 15 '25 19:05

@soulitzer I think this issue could likely be related to https://github.com/pytorch/torchtitan/issues/1323#issuecomment-3145324389 We should really save the routing results, whichever ops they come from.

tianyu-l, Aug 21 '25 07:08

@tianyu-l and @soulitzer, I was thinking about trying preserve_rng_state=True. The recompute difference shows up only for larger networks with MoE. It also appears that distributed training is not necessary for this to show up.

githubsgi, Aug 21 '25 17:08

Not necessarily preserve_rng_state, but likely some numerical difference that has been magnified, followed by an op that produces data-dependent shapes. Unassigning, because there's unlikely to be anything that can be done on the AC framework side. However, as tianyu points out, saving intermediate results may help.
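To illustrate this mechanism, here is a minimal sketch that assumes top-1 routing over two experts, as in the rank4 log later in this thread (topk with k=1 over two experts). `expert_counts` is a hypothetical helper that counts tokens with `torch.bincount` rather than the `histc` seen in the log. The 2e-4 nudge is an arbitrary stand-in for whatever numerical difference actually occurs, not a claim about the real cause.

```python
import torch


def expert_counts(scores: torch.Tensor, num_experts: int) -> list:
    """Top-1 routing: number of tokens sent to each expert. In the MoE layer
    such counts become split sizes, i.e. data-dependent shapes."""
    _, expert_idx = torch.topk(scores, k=1, dim=1)
    return torch.bincount(expert_idx.view(-1), minlength=num_experts).tolist()


torch.manual_seed(0)
scores = torch.rand(128, 2)               # 128 tokens, 2 experts
scores[0] = torch.tensor([0.5, 0.5001])   # token 0 is a near-tie

original = expert_counts(scores, num_experts=2)

# Hypothetical drift on a single score, standing in for a small numerical
# difference between the original forward and the recompute.
drifted = scores.clone()
drifted[0, 0] += 2e-4
recomputed = expert_counts(drifted, num_experts=2)
```

Both count lists sum to 128, but token 0 moves from expert 1 to expert 0. The split sizes therefore shift by one in each position, the same kind of off-by-one shift seen in the split_with_sizes logs in this thread.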

soulitzer, Aug 21 '25 18:08

It is the MoE router that sprays the tokens to different experts that exposes this issue readily.

githubsgi, Aug 21 '25 18:08

@githubsgi, do you know which of these gives this discrepancy between the forward pass and the backward recomputation?

  • the router.gate https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/moe.py#L211
  • the sigmoid https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/moe.py#L215
  • the topk https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/moe.py#L225

tianyu-l, Aug 21 '25 19:08

@tianyu-l, it is hard to tell from the debug log where the divergence starts. The difference shows up in the split_with_sizes.default output.

[rank4]: ['$182: bf16[63, 5120]', '$183: bf16[65, 5120]'] = torch._ops.aten.split_with_sizes.default($181, ['63', '65'])   (140 of 177 in original: )
[rank4]: ['$246: bf16[62, 5120]', '$247: bf16[66, 5120]'] = torch._ops.aten.split_with_sizes.default($245, ['62', '66'])   (204 of 285 in recompute: )

Original forward:
[rank4]: $154: bf16[128, 2] = torch._ops.aten.view.default($153, ['128', '2'])
[rank4]: $155: f32[128, 2] = torch._ops.aten._to_copy.default($154, dtype=torch.float32)
[rank4]: $156: f32[128, 2] = torch._ops.aten.sigmoid.default($155)
[rank4]: $158: f32[128, 2] = torch._ops.aten.add.Tensor($156, $157)
[rank4]: ('$159: f32[128, 1]', '$160: i64[128, 1]') = torch._ops.aten.topk.default($158, 1, 1)
[rank4]: $161: f32[128, 1] = torch._ops.aten.gather.default($156, 1, $160)
[rank4]: $162: i64[128] = torch._ops.aten.view.default($160, ['-1'])
[rank4]: $163: i64[2] = torch._ops.aten.histc.default($162, 2, 0, 2)
[rank4]: $164: i64[128] = torch._ops.aten.view.default($160, ['-1'])
[rank4]: ('$165: i64[128]', '$166: i64[128]') = torch._ops.aten.sort.stable($164, stable=True)
[rank4]: $167: f32[128] = torch._ops.aten.view.default($161, ['-1'])
[rank4]: $168: f32[128] = torch._ops.aten.index.Tensor($167, ['$166'])
[rank4]: $169: i64[128] = torch._ops.aten.floor_divide.default($166, 1)
[rank4]: $170: f32[2] = torch._ops.aten.add_.Tensor($170, $163)
[rank4]: $171: i64[128, 1] = torch._ops.aten.view.default($169, ['-1', '1'])
[rank4]: $172: i64[128, 5120] = torch._ops.aten.expand.default($171, ['-1', '5120'])
[rank4]: $173: bf16[128, 5120] = torch._ops.aten.view.default($146, ['-1', '5120'])
[rank4]: $174: bf16[128, 5120] = torch._ops.aten.gather.default($173, 0, $172)
[rank4]: $175: f32[128, 5120] = torch._ops.aten._to_copy.default($174, dtype=torch.float32)
[rank4]: $176: f32[128, 1] = torch._ops.aten.view.default($168, ['-1', '1'])
[rank4]: $177: f32[128, 5120] = torch._ops.aten.mul.Tensor($175, $176)
[rank4]: $178: bf16[128, 5120] = torch._ops.aten._to_copy.default($177, dtype=torch.bfloat16)
[rank4]: $179: i64[2] = torch._ops.aten._to_copy.default($163, dtype=torch.int64, layout=torch.strided, device=device(type='cpu'))
[rank4]: $180: bf16[128, 5120] = torch._ops.aten.view.default($178, ['128', '5120'])
[rank4]: ['$182: bf16[63, 5120]', '$183: bf16[65, 5120]'] = torch._ops.aten.split_with_sizes.default($181, ['63', '65'])

Recompute:
[rank4]: $208: bf16[128, 2] = torch._ops.aten.view.default($207, ['128', '2'])
[rank4]: $209: f32[128, 2] = torch._ops.aten._to_copy.default($208, dtype=torch.float32)
[rank4]: $210: f32[128, 2] = torch._ops.aten.sigmoid.default($209)
[rank4]: $211: f32[128, 2] = torch._ops.aten.detach.default($210)
[rank4]: $212: f32[128, 2] = torch._ops.aten.detach.default($211)
[rank4]: $214: f32[128, 2] = torch._ops.aten.add.Tensor($210, $213)
[rank4]: ('$215: f32[128, 1]', '$216: i64[128, 1]') = torch._ops.aten.topk.default($214, 1, 1)
[rank4]: $217: f32[128, 2] = torch._ops.aten.detach.default($210)
[rank4]: $218: f32[128, 2] = torch._ops.aten.detach.default($217)
[rank4]: $219: f32[128, 1] = torch._ops.aten.gather.default($210, 1, $216)
[rank4]: $220: i64[128] = torch._ops.aten.view.default($216, ['-1'])
[rank4]: $221: i64[2] = torch._ops.aten.histc.default($220, 2, 0, 2)
[rank4]: $222: i64[128] = torch._ops.aten.view.default($216, ['-1'])
[rank4]: ('$223: i64[128]', '$224: i64[128]') = torch._ops.aten.sort.stable($222, stable=True)
[rank4]: $225: f32[128] = torch._ops.aten.view.default($219, ['-1'])
[rank4]: $226: f32[128] = torch._ops.aten.index.Tensor($225, ['$224'])
[rank4]: $227: i64[128] = torch._ops.aten.floor_divide.default($224, 1)
[rank4]: $228: f32[2] = torch._ops.aten.add_.Tensor($228, $221)
[rank4]: $229: i64[128, 1] = torch._ops.aten.view.default($227, ['-1', '1'])
[rank4]: $230: i64[128, 5120] = torch._ops.aten.expand.default($229, ['-1', '5120'])
[rank4]: $231: bf16[128, 5120] = torch._ops.aten.view.default($196, ['-1', '5120'])
[rank4]: $232: bf16[128, 5120] = torch._ops.aten.detach.default($231)
[rank4]: $233: bf16[128, 5120] = torch._ops.aten.detach.default($232)
[rank4]: $234: bf16[128, 5120] = torch._ops.aten.gather.default($231, 0, $230)
[rank4]: $235: f32[128, 5120] = torch._ops.aten._to_copy.default($234, dtype=torch.float32)
[rank4]: $236: f32[128, 1] = torch._ops.aten.view.default($226, ['-1', '1'])
[rank4]: $237: f32[128, 1] = torch._ops.aten.detach.default($236)
[rank4]: $238: f32[128, 1] = torch._ops.aten.detach.default($237)
[rank4]: $239: f32[128, 5120] = torch._ops.aten.detach.default($235)
[rank4]: $240: f32[128, 5120] = torch._ops.aten.detach.default($239)
[rank4]: $241: f32[128, 5120] = torch._ops.aten.mul.Tensor($235, $236)
[rank4]: $242: bf16[128, 5120] = torch._ops.aten._to_copy.default($241, dtype=torch.bfloat16)
[rank4]: $243: i64[2] = torch._ops.aten._to_copy.default($221, dtype=torch.int64, layout=torch.strided, device=device(type='cpu'))
[rank4]: $244: bf16[128, 5120] = torch._ops.aten.view.default($242, ['128', '5120'])
[rank4]: ['$246: bf16[62, 5120]', '$247: bf16[66, 5120]'] = torch._ops.aten.split_with_sizes.default($245, ['62', '66'])



githubsgi, Aug 21 '25 23:08

@githubsgi, I would guess that if you save the output of torch._ops.aten.topk.default or torch._ops.aten.sigmoid.default (by putting the op into the _save_list https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/infra/parallelize.py#L237), or of the router.gate mm (by following https://github.com/pytorch/torchtitan/pull/1589), the differences would go away.

It would be nice if you could try and see which works.

tianyu-l, Aug 21 '25 23:08

I added the following to the save list. I needed to use full recompute, though. I still see the recompute difference.

    torch.ops.aten.topk.default,
    torch.ops.aten.sigmoid.default

@tianyu-l, @soulitzer, a couple of questions on the debug log:

  1. Is the number of operators in forward and recompute expected to be the same? I am seeing $460 vs $815. How should they line up?
[rank19]: torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.
[rank19]: saved metadata: {'shape': torch.Size([696, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
[rank19]: recomputed metadata: {'shape': torch.Size([704, 5120]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
[rank19]: saved metadata: {'shape': torch.Size([696, 1024]), 'dtype': torch.bfloat16, 'device': device(type='xpu', index=3)}
.
.
.
[rank19]: $448: bf16[1, 8192, 5120] = torch._ops.aten.view.default($447, ['1', '8192', '5120'])   (409 of 415 in original)
[rank19]: $457: bf16[8, 1024, 5120] = torch._ops.aten.cat.default(['$449', '$450', '$451', '$452', '$453', '$454', '$455', '$456'])   (411 of 415 in original)
[rank19]: $458: bf16[1, 1024, 5120] = torch._ops._c10d_functional.reduce_scatter_tensor.default($457, 'sum', 8, '11')   (412 of 415 in original)
[rank19]: $458: bf16[1, 1024, 5120] = torch._ops._c10d_functional.wait_tensor.default($458)   (413 of 415 in original)
[rank19]: $459: bf16[1, 1024, 5120] = torch._ops.aten.view.default($458, ['1', '1024', '5120'])   (414 of 415 in original)
[rank19]: $460: bf16[1, 1024, 5120] = torch._ops.aten.add.Tensor($123, $459)   (415 of 415 in original)

  2. It looks like recompute is occurring twice. Is that the right way to interpret the following?
[rank19]: $809: bf16[8192, 5120] = torch._ops.aten.detach.default($808)   (772 of 776 in recompute)
[rank19]: $811: bf16[8192, 8192] = torch._ops.aten.detach.default($810)   (773 of 776 in recompute)
[rank19]: $812: bf16[8192, 8192] = torch._ops.aten.detach.default($811)   (774 of 776 in recompute)
[rank19]: $813: bf16[8192, 5120] = torch._ops.aten.mm.default($810, $807)   (775 of 776 in recompute)
[rank19]: $815: bf16[8192, 5120] = torch._ops.aten.view.default($814, ['8192', '5120'])   (776 of 776 in recompute)
.
.
.
[rank19]: $809: bf16[8192, 5120] = torch._ops.aten.detach.default($808)
[rank19]: $811: bf16[8192, 8192] = torch._ops.aten.detach.default($810)
[rank19]: $812: bf16[8192, 8192] = torch._ops.aten.detach.default($811)
[rank19]: $813: bf16[8192, 5120] = torch._ops.aten.mm.default($810, $807)
[rank19]: $815: bf16[8192, 5120] = torch._ops.aten.view.default($814, ['8192', '5120'])

githubsgi, Aug 27 '25 19:08

Is the number of operators in forward and recompute expected to be same?

Not exactly. For normal AC, only the number, order, and properties of the tensors saved for backward are expected to be the same. For SAC, we DO have that requirement for the ops that are marked as saved.
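Under plain AC only the saved tensors have to line up, so one rough way to compare the two logs is to drop the extra `aten.detach` calls and diff the remaining op names. (The recompute excerpts in this thread contain these extra detach calls; the original ones do not.) This is a standard-library sketch. `op_names` and `align` are hypothetical helpers, and filtering out detach is a heuristic, not something the checkpoint code itself does.

```python
import difflib
import re

# Captures e.g. "aten.topk.default" from "torch._ops.aten.topk.default($158, 1, 1)"
OP_RE = re.compile(r"torch\._ops\.([\w.]+)\(")


def op_names(log_lines, ignore=("aten.detach",)):
    """Extract op names from checkpoint debug-log lines, dropping ops whose
    names start with any prefix in `ignore`."""
    names = []
    for line in log_lines:
        m = OP_RE.search(line)
        if m and not m.group(1).startswith(ignore):
            names.append(m.group(1))
    return names


def align(original_lines, recompute_lines):
    """Return difflib opcodes for the places where the filtered op
    sequences of the original forward and the recompute disagree."""
    a, b = op_names(original_lines), op_names(recompute_lines)
    matcher = difflib.SequenceMatcher(a=a, b=b, autojunk=False)
    return [op for op in matcher.get_opcodes() if op[0] != "equal"]
```

On the rank4 router excerpts earlier in this thread (view, _to_copy, sigmoid, add, topk), the two sequences should align with no differences once the detach lines are dropped.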

Looks like recompute is occurring twice. Is that the right way to interpret the following?

Doesn't seem like it. I think the logs are just repeated there.

soulitzer, Aug 28 '25 19:08