[BUG] None grad triggers exception in the backward hook
Describe the bug
When training Llama2 with DeepCompile enabled, the backward engine seems to pass two losses to the backward graph, with one of them being None and not actually used by the graph.
With #7665, a backward hook is registered to scale the loss by the number of gradient accumulation steps, but that hook assumes the grad is always a valid tensor. As DeepCompile breaks that assumption, the hook triggers a TypeError exception.
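For illustration only, here is a minimal standalone sketch (hypothetical names, not the actual DeepSpeed code) of a per-tensor grad-scaling hook in the spirit of #7665, showing why a None grad breaks it:

```python
import torch

gradient_accumulation_steps = 4

def backward_prologue_per_tensor(grad):
    # Scale the incoming gradient by the accumulation steps.
    # Assumes `grad` is always a tensor; a None grad raises
    # TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'.
    return grad / gradient_accumulation_steps

loss = torch.tensor(1.0, requires_grad=True)
loss.register_hook(backward_prologue_per_tensor)
loss.backward()  # eager mode: the hook receives a real gradient tensor

try:
    backward_prologue_per_tensor(None)  # what the hook sees under DeepCompile
except TypeError as e:
    print(e)  # unsupported operand type(s) for /: 'NoneType' and 'int'
```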
To Reproduce
Steps to reproduce the behavior:
- Run https://gist.github.com/eternalNight/8277302920fb8fb5aa7eb4fb9e3fb6a3
Traceback:
[rank0]: Traceback (most recent call last):
[rank0]: File "/test_simple.py", line 67, in <module>
[rank0]: m.backward(loss.loss)
[rank0]: File "/mnt/engines/deepspeed/deepspeed/utils/nvtx.py", line 20, in wrapped_fn
[rank0]: ret_val = func(*args, **kwargs)
[rank0]: File "/mnt/engines/deepspeed/deepspeed/runtime/engine.py", line 2466, in backward
[rank0]: loss.backward(**backward_kwargs)
[rank0]: File "/workspaces/venv/lib/python3.10/site-packages/torch/_tensor.py", line 639, in backward
[rank0]: return handle_torch_function(
[rank0]: File "/workspaces/venv/lib/python3.10/site-packages/torch/overrides.py", line 1721, in handle_torch_function
[rank0]: result = mode.__torch_function__(public_api, types, args, kwargs)
[rank0]: File "/workspaces/venv/lib/python3.10/site-packages/torch/utils/_device.py", line 104, in __torch_function__
[rank0]: return func(*args, **kwargs)
[rank0]: File "/workspaces/venv/lib/python3.10/site-packages/torch/_tensor.py", line 648, in backward
[rank0]: torch.autograd.backward(
[rank0]: File "/workspaces/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 353, in backward
[rank0]: _engine_run_backward(
[rank0]: File "/workspaces/venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 824, in _engine_run_backward
[rank0]: return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
[rank0]: File "/mnt/engines/deepspeed/deepspeed/runtime/utils.py", line 1271, in backward_hook
[rank0]: grad = self.preprocess_per_tensor_fn(grad)
[rank0]: File "/mnt/engines/deepspeed/deepspeed/runtime/engine.py", line 2335, in _backward_prologue_per_tensor
[rank0]: return grad / self.gradient_accumulation_steps()
[rank0]: TypeError: unsupported operand type(s) for /: 'NoneType' and 'int'
Expected behavior
The training should complete normally.
More information
The two losses in the torch profiler trace:
@tohtana Any idea if this two-grad phenomenon is expected? If so, should we add a None check at the beginning of _backward_prologue_per_tensor?
Thank you for reporting, @eternalNight! I didn't expect this case. Do you think we can simply skip the scaling when the given value is None?
@tohtana I'll investigate why there is a second grad in the compiled backward graph to see if it is safe to skip None grads.
@tohtana Here's the story:
- Llama2 returns a CausalLMOutputWithPast (which extends dict) containing a loss tensor (of size 1) and a logits tensor. DeepSpeed registers the backward hook on that object, and thus hooks both tensors.
- Only when DeepCompile is enabled, the backward hooks of both tensors are called, and the hook on logits is called with a None grad, probably because it is not a scalar.
- The hook on logits is not called at all without DeepCompile, regardless of whether torch.compile is applied or not. Somehow autograd remembers both tensors and calls their backward hooks before running the compiled backward pass.
As the None grads are caused by non-scalar tensors, I think it is safe to simply skip them in the hook (see the sketch below). What do you think?
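For concreteness, a minimal standalone sketch of the skip (not the actual engine code; the real hook is _backward_prologue_per_tensor in deepspeed/runtime/engine.py), assuming the hook may now receive None for non-scalar outputs such as logits:

```python
def backward_prologue_per_tensor(grad, gradient_accumulation_steps):
    # Under DeepCompile, hooks on non-scalar outputs (e.g. logits) can fire
    # with grad=None; there is nothing to scale, so return it unchanged.
    if grad is None:
        return None
    return grad / gradient_accumulation_steps
```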
Thank you @eternalNight for your investigation! Yes, let's simply skip the None grad in the hook. Can you open a PR for it?
@eternalNight Sorry, as I also wanted to run Llama models, I opened #7725. Can you review it? I confirmed it works with Llama-3 7B.