Unknown attribute _base inside Megatron core

Open tfogal opened this issue 1 year ago • 12 comments

🚀 Model / language coverage

Running Megatron GPT from NeMo, we seem to have issues with this line from Megatron core. Some context from the caller in this particular case might be illuminating.

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/tfogal/dev/nemo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py", line 88, in <module>
[rank0]:     main()
[rank0]:   File "/home/tfogal/dev/nemo/nemo/core/config/hydra_runner.py", line 129, in wrapper
[rank0]:     _run_hydra(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 394, in _run_hydra
[rank0]:     _run_app(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 457, in _run_app
[rank0]:     run_and_report(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 223, in run_and_report
[rank0]:     raise ex
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 220, in run_and_report
[rank0]:     return func()
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/utils.py", line 458, in <lambda>
[rank0]:     lambda: hydra.run(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/_internal/hydra.py", line 132, in run
[rank0]:     _ = ret.return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 260, in return_value
[rank0]:     raise self._return_value
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/hydra/core/utils.py", line 186, in run_job
[rank0]:     ret.return_value = task_function(task_cfg)
[rank0]:   File "/home/tfogal/dev/nemo/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py", line 84, in main
[rank0]:     trainer.fit(model)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 543, in fit
[rank0]:     call._call_and_handle_interrupt(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 43, in _call_and_handle_interrupt
[rank0]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
[rank0]:     return function(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 579, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 986, in _run
[rank0]:     results = self._run_stage()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1028, in _run_stage
[rank0]:     self._run_sanity_check()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/trainer.py", line 1057, in _run_sanity_check
[rank0]:     val_loop.run()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/utilities.py", line 182, in _decorator
[rank0]:     return loop_run(self, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
[rank0]:     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
[rank0]:     output = call._call_strategy_hook(trainer, hook_name, *step_args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py", line 311, in _call_strategy_hook
[rank0]:     output = fn(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/pytorch_lightning/strategies/strategy.py", line 411, in validation_step
[rank0]:     return self.lightning_module.validation_step(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 434, in validation_step
[rank0]:     return self.inference_step(dataloader_iter, 'validation')
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 444, in inference_step
[rank0]:     outputs = self.inference_step_validation_call(batch, batch_idx, data_cfg, dataloader_idx)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 464, in inference_step_validation_call
[rank0]:     loss = super().validation_step(itertools.chain([batch]), dataloader_idx)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1370, in validation_step
[rank0]:     loss = self.fwd_bwd_step(dataloader_iter, True, first_val_step)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py", line 381, in fwd_bwd_step
[rank0]:     losses_reduced_per_micro_batch = fwd_bwd_function(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 392, in forward_backward_no_pipelining
[rank0]:     output_tensor, num_tokens = forward_step(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/pipeline_parallel/schedules.py", line 217, in forward_step
[rank0]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank0]:   File "/home/tfogal/dev/nemo/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py", line 1232, in fwd_output_and_loss_func
[rank0]:     output_tensor = model(**forward_args)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/module.py", line 61, in forward
[rank0]:     res = self._forward_fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 669, in fn_
[rank0]:     cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 223, in cache_info_wrapper
[rank0]:     res = fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 503, in get_computation_and_inputs
[rank0]:     jit_results: TraceResults = interpreter(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/__init__.py", line 211, in _general_frontend
[rank0]:     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 1743, in thunder_general_jit
[rank0]:     result = jfn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6695, in fn_
[rank0]:     raise e
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6663, in fn_2
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/models/gpt/gpt_model.py", line 190, in forward
[rank0]:     hidden_states = self.decoder(
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1716, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1727, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/transformer/transformer_block.py", line 361, in forward
[rank0]:     hidden_states = make_viewless_tensor(
[rank0]:   File "/home/tfogal/env/lib/python3.10/site-packages/megatron/core/utils.py", line 147, in make_viewless_tensor
[rank0]:     if inp._base is None:
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 6060, in _impl
[rank0]:     return fn.__func__(fn.__self__, *args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/interpreter.py", line 1272, in wrapping_wrapper
[rank0]:     res = ufn(*uargs, **ukwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/jit_ext.py", line 704, in wrapper
[rank0]:     return fn(*args, **kwargs)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/proxies.py", line 1312, in __getattr__
[rank0]:     baseutils.check(method_or_value is not None, lambda: f"Unknown attribute {attr}", exception_type=AttributeError)
[rank0]:   File "/home/tfogal/dev/thunder/thunder/core/baseutils.py", line 103, in check
[rank0]:     raise exception_type(s())
[rank0]: AttributeError: Unknown attribute _base. Did you mean: '_return_value'?

Pitch

This comes from a finetuning case of the Megatron GPT network. To set up and run the network, see #344.

Alternatives / Potential work-arounds

Marking triage review for help

Minimal Repro

(help wanted)

cc @apaz-cli @tfogal

tfogal avatar Jul 10 '24 23:07 tfogal

At first glance, I would probably try making a lookaside for make_viewless_tensor that is a no-op. I don't think we want to model _base, and even less do we want to force it to be None in ways that mess with the internals of PyTorch and are likely to have unforeseen consequences. That said, a clearer error message for _base would be nice.
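
As a rough illustration of the "nicer error message" idea, here is a minimal pure-Python sketch. `SketchTensorProxy` and its `_UNSUPPORTED_HINTS` table are hypothetical names invented for this example; this is not Thunder's `TensorProxy` and does not reflect how `thunder/core/proxies.py` is actually structured.

```python
class SketchTensorProxy:
    """Hypothetical stand-in for a tracing proxy; not Thunder's TensorProxy."""

    # Attributes that are deliberately not modeled, each with a hint for the user.
    _UNSUPPORTED_HINTS = {
        "_base": (
            "view tracking is not modeled during tracing; "
            "consider a lookaside for the function that reads it"
        ),
    }

    def __init__(self, name, shape):
        self.name = name
        self.shape = shape

    def __getattr__(self, attr):
        # __getattr__ only runs after normal attribute lookup has failed.
        hint = self._UNSUPPORTED_HINTS.get(attr)
        if hint is not None:
            raise AttributeError(
                f"Unsupported attribute {attr!r} on proxy {self.name}: {hint}"
            )
        raise AttributeError(f"Unknown attribute {attr}")
```

With this shape, reading `_base` fails with a message that points at the lookaside route, while any other missing attribute keeps the generic "Unknown attribute" text.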

t-vi avatar Jul 11 '24 05:07 t-vi

At first glance, I would probably try making a lookaside for make_viewless_tensor that is a no-op.

I agree with this. I'm not a fan of the method in question.

crcrpar avatar Jul 11 '24 07:07 crcrpar

  • _base links a tensor to what it actually refers to
  • megatron does "fun" things--tensors that share memory but claim that they are not views. this is what make_viewless_tensor does. avoiding garbage collection?
  • what is megatron doing here / why do they do this? are we fully understanding megatron's goal here?
  • make make_viewless_tensor from megatron a no-op?
  • @tfogal let's talk to megatron
  • https://discuss.pytorch.org/t/variable-execution-engine-run-backward-causes-memory-leak/205546/3

tfogal avatar Jul 15 '24 15:07 tfogal

Marking high priority as it turns out the 'correct' configuration of NeVA actually runs into this (earlier we avoided this problem by running NeVA in a different mode).

tfogal avatar Aug 09 '24 22:08 tfogal

The developer who wrote this is out but I had a great chat with @jaredcasper who was wicked helpful:

  • This is an optimization related to pipeline parallelism.
  • The problem this solves is excessive memory use: once we've fed the output of one pipeline stage to the input of the next stage, we no longer need the outputs of the first pipeline stage. PyTorch was adamant about not deleting tensors from the first stage, though; despite using PyTorch comms, torch couldn't seem to make sense of the pipeline parallelism strategy. This strategy allows megatron to effectively delete the tensor's data (saving memory) without deleting the tensor outright (and thus angering torch).
  • viewless tensors feed back into deallocate_output_tensor, which enables megatron to delete the tensor's data w/o deleting the tensor itself
  • We should check our config. There's a chance (admittedly small) that NeVA and/or other NeMo models never set deallocate_pipeline_output, which would mean it doesn't even use this optimization.
  • The quick-and-dirty approach is, as @t-vi surmised above, to add a lookaside that makes this a no-op. Or, rather, a view or maybe something weird like foo.new_tensor(foo.data_ptr()). The easiest is probably to just clone() the tensor. This will have a negative impact on memory use but should get us running.
    • If/when we implement proper memory planning, that planning would be able to easily figure out that an earlier pipeline stage's tensors are dead just through live range analysis. This would then restore the memory optimization.
  • There is a small chance that this optimization is no longer needed. If so, we need to support 24.01 containers but if that version of PyTorch can do the right thing w/o this, then upstream would consider any patches we sent.
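
To make the live range analysis point concrete, here is a minimal sketch over a toy linear trace. The trace format (tuples of result name, op name, input names) and the op names are assumptions made up for this example; they are not Thunder's trace representation or Megatron's schedule.

```python
def last_uses(trace, outputs):
    """Map each tensor name to the index of the last op that reads it.

    `trace` is a list of (result_name, op_name, input_names) tuples in
    execution order. Names in `outputs` are treated as live until the end.
    """
    last = {}
    for idx, (_result, _op, inputs) in enumerate(trace):
        for name in inputs:
            last[name] = idx
    for name in outputs:
        last[name] = len(trace)
    return last


def free_points(trace, outputs):
    """Return, per op index, the results that can be freed right after it."""
    last = last_uses(trace, outputs)
    frees = {idx: [] for idx in range(len(trace))}
    for idx, (result, _op, _inputs) in enumerate(trace):
        # A result that nobody reads is dead immediately after it is produced.
        end = last.get(result, idx)
        if end < len(trace):
            frees[end].append(result)
    return frees


# Toy two-stage pipeline: stage 0's output is only needed until it is sent on.
trace = [
    ("stage0_out", "stage0", ["x"]),
    ("sent", "send_to_next_stage", ["stage0_out"]),
    ("loss", "stage1", ["sent"]),
]
print(free_points(trace, outputs=["loss"]))
# {0: [], 1: ['stage0_out'], 2: ['sent']}
```

In this toy case the analysis finds that `stage0_out` is dead right after the send, which is the situation the Megatron optimization handles by hand. The sketch only tracks results produced inside the trace, so the external input `x` is never scheduled for freeing.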

Thanks, Jared! Marking triage review to make sure we discuss.

tfogal avatar Aug 21 '24 20:08 tfogal

Playing with this, are you sure that .detach() is not enough (in recent versions of PyTorch at least)?

import torch
print(torch.__version__)  # 2.4.0+cu124
a = torch.randn(100, 100)
b = a.view(100, 100)
assert b._base is a
c = a.detach()
assert c._base is None
assert c.data_ptr() == a.data_ptr()

So I would expect this to be equivalent to the .data business for _kernel_make_viewless_tensor

def _kernel_make_viewless_tensor(inp, requires_grad):
    return inp.detach().requires_grad_(requires_grad)

Because of how we currently fudge torch.autograd.Function, the detach will not be good for Thunder, so I would start with this:

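import megatron.core.utils
import thunder.core.jit_ext

@thunder.core.jit_ext.register_general_jit_lookaside(
    megatron.core.utils.make_viewless_tensor
)
@thunder.core.jit_ext.interpreter_needs_wrap
def make_viewless_tensor_lookaside(inp, requires_grad, keep_graph):
    # No-op while tracing: hand the proxy back unchanged.
    return inp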

for now given that we will be running this in tracing, i.e. with TensorProxies. @tfogal Could you give this a spin please?

t-vi avatar Aug 22 '24 08:08 t-vi

@tfogal @t-vi I gave it a go and the WAR seems to work; however, the run still crashes later on with a type mismatch. The type mismatch does not seem related to this workaround, but I cannot be 100% certain at the moment. The next error is:

 AssertionError: Data types for parameters must match when outside of autocasted region.  Found input dtype: thunder.dtypes.float32 and 'layer_norm_weight' dtype: thunder.dtypes.bfloat16

It comes from Megatron-LM/megatron/core/transformer/transformer_block.py, line 398, in forward, at the call hidden_states, context = layer(, where the layer is a megatron.core.transformer.transformer_layer.TransformerLayer.

riccardofelluga avatar Aug 23 '24 13:08 riccardofelluga

Could you give this a spin

I gave it a go and

hah, love when people jump on things while I'm asleep :-)

Riccardo, it's not clear whether the workaround you refer to was TomV's latest reply (of changing megatron to use detach), or of TomV's earlier thought of creating a no-op lookaside for make_viewless_tensor. If it's the latter, would it make sense as a thunder PR?

tfogal avatar Aug 23 '24 15:08 tfogal

The next error is:

I'd agree with you that this is likely unrelated. If you could get an issue filed that would be helpful, but a PR would be more important so that other people could even reach that error.

tfogal avatar Aug 23 '24 15:08 tfogal

Riccardo, it's not clear whether the workaround you refer to was TomV's latest reply (of changing megatron to use detach), or of TomV's earlier thought of creating a no-op lookaside for make_viewless_tensor. If it's the latter, would it make sense as a thunder PR?

@tfogal sorry for the fast write up 😅 so the WAR I tried is the no-op! I'll open an issue with the repro and the latest error

riccardofelluga avatar Aug 26 '24 07:08 riccardofelluga

You can repro the error in this issue by using the instructions in #1044 and commenting out the lookaside WAR

riccardofelluga avatar Aug 26 '24 08:08 riccardofelluga

Riccardo, it's not clear whether the workaround you refer to was TomV's latest reply (of changing megatron to use detach), or of TomV's earlier thought of creating a no-op lookaside for make_viewless_tensor.

@tfogal sorry for the fast write up 😅 so the WAR I tried is the no-op! I'll open an issue with the repro and the latest error

No need to apologize! When we're all back from free days could you send a PR with the no-op lookaside patch you have?

tfogal avatar Sep 11 '24 20:09 tfogal