Issue in manual optimisation during the self.manual_backward call
Bug description
I have set automatic_optimization to False and am using self.manual_backward to compute and populate the gradients. The code breaks during the self.manual_backward call, raising "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn". I have posted code below to reproduce the issue. The issue does not arise when I set args['use_minibatch_clip_loss'] = False, or when I set args['batch_size'] = args['minibatch_size'] = 16. I suspect the issue only arises when I try to run backward after running the model under torch.no_grad().
What version are you seeing the problem on?
v2.2
How to reproduce the bug
import os
import math
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import lightning.pytorch as pl
from lightning.pytorch import loggers as pl_loggers


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.query(input)


class TestLitModel(pl.LightningModule):
    def __init__(self, args):
        super().__init__()
        self.test_module_obj = TestModule(args['in_dim'], args['out_dim'])
        self.use_minibatch_clip_loss = args['use_minibatch_clip_loss']
        if self.use_minibatch_clip_loss:
            self.batch_size = args['batch_size']
            self.minibatch_size = args['minibatch_size']
            self.accumulate_grad_batches_mb = args['accumulate_grad_batches_mb']
            self.automatic_optimization = False

    def get_mini_batches(self, input):
        num_mb = math.ceil(self.batch_size / self.minibatch_size)
        return torch.chunk(input, num_mb)

    def shared_step(self, input):
        output = self.test_module_obj(input)
        loss = output.mean()
        return loss

    def train_step_minibatch(self, input, batch_idx):
        if self.batch_size > self.minibatch_size:
            mini_batches = self.get_mini_batches(input)
            mb_model_output_list = list()
            with torch.no_grad():
                for mb in mini_batches:
                    mb_model_output_list.append(self.shared_step(mb).detach())
            all_loss = sum(mb_model_output_list)
            self.test_module_obj.train()
            self.test_module_obj.requires_grad_(True)
            torch.set_grad_enabled(True)
            assert torch.is_grad_enabled()
            assert all(p.requires_grad for p in self.test_module_obj.parameters())
            for _, mb in enumerate(mini_batches):
                mb_model_output = self.shared_step(mb)
                self.manual_backward(mb_model_output)
        else:
            all_loss = self.shared_step(input)
            self.manual_backward(all_loss)
        # get optimizers and scheduler
        if (batch_idx + 1) % self.accumulate_grad_batches_mb == 0:
            optimizer = self.optimizers()
            if isinstance(optimizer, list):
                optimizer = optimizer[0]
            optimizer.step()
            optimizer.zero_grad()
        return all_loss

    def training_step(self, batch, batch_idx):
        input = batch[0]
        if self.use_minibatch_clip_loss:
            loss = self.train_step_minibatch(input, batch_idx)
        else:
            loss = self.shared_step(input)
        return loss

    def validation_step(self, batch, batch_idx):
        input = batch[0]
        loss = self.shared_step(input)
        return loss

    def configure_optimizers(self):
        self.optimizer = torch.optim.AdamW(list(self.test_module_obj.parameters()), lr=0.0002, weight_decay=0.01)
        return {"optimizer": self.optimizer}


if __name__ == '__main__':
    args = {
        'in_dim': 512,
        'out_dim': 16,
        'train_batch_size': 16,
        'val_batch_size': 64,
        'use_minibatch_clip_loss': True,
        'batch_size': 16,
        'minibatch_size': 4,
        'accumulate_grad_batches_mb': 1,
    }
    x_dummy = torch.randn(512, args['in_dim'])  # 512 samples, args['in_dim'] features each
    test_data_loader = DataLoader(TensorDataset(x_dummy), batch_size=args['train_batch_size'], shuffle=False)  # Dummy dataset
    test_lit_model = TestLitModel(args)
    # -- LOGGING
    checkpoint_dir = 'test_logs/'
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.join(checkpoint_dir, "logs"))
    trainer = pl.Trainer(
        logger=tb_logger,
        accelerator='gpu',
        devices=[1],
        strategy='auto',
        precision='16-mixed',
        max_epochs=1,
        accumulate_grad_batches=1,
        num_sanity_val_steps=0,
        inference_mode=False,
    )
    trainer.fit(test_lit_model, test_data_loader)
Error messages and logs
Epoch 0: 0%| | 0/32 [00:00<?, ?it/s]Traceback (most recent call last):
File "/home/users/pranav.rao/foundational_models/test1.py", line 119, in <module>
trainer.fit(test_lit_model, test_data_loader)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
call._call_and_handle_interrupt(
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
results = self._run_stage()
^^^^^^^^^^^^^^^^^
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1033, in _run_stage
self.fit_loop.run()
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
self.advance()
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
self.epoch_loop.run(self._data_fetcher)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
self.advance(data_fetcher)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 252, in advance
batch_output = self.manual_optimization.run(kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py", line 94, in run
self.advance(kwargs)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/loops/optimization/manual.py", line 114, in advance
training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 391, in training_step
return self.lightning_module.training_step(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/pranav.rao/foundational_models/test1.py", line 74, in training_step
loss = self.train_step_minibatch(input, batch_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/users/pranav.rao/foundational_models/test1.py", line 57, in train_step_minibatch
self.manual_backward(mb_model_output)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 1071, in manual_backward
self.trainer.strategy.backward(loss, None, *args, **kwargs)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 213, in backward
self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, *args, **kwargs)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/plugins/precision/precision.py", line 72, in backward
model.backward(tensor, *args, **kwargs)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/lightning/pytorch/core/module.py", line 1090, in backward
loss.backward(*args, **kwargs)
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/torch/_tensor.py", line 525, in backward
torch.autograd.backward(
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/torch/autograd/__init__.py", line 267, in backward
_engine_run_backward(
File "/home/users/pranav.rao/miniconda3/envs/pytorch_lightning_bug_report/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Environment
Current environment
- CUDA:
    - GPU:
        - NVIDIA A100 80GB PCIe
        - NVIDIA A100 80GB PCIe
        - NVIDIA A100 80GB PCIe
        - NVIDIA A100 80GB PCIe
    - available: True
    - version: 12.1
- Lightning:
    - lightning: 2.2.3
    - lightning-utilities: 0.11.2
    - pytorch-lightning: 2.2.3
    - torch: 2.3.0
    - torchmetrics: 1.3.2
- Packages:
    - absl-py: 2.1.0
    - aiohttp: 3.9.5
    - aiosignal: 1.3.1
    - attrs: 23.2.0
    - filelock: 3.13.4
    - frozenlist: 1.4.1
    - fsspec: 2024.3.1
    - grpcio: 1.62.2
    - idna: 3.7
    - jinja2: 3.1.3
    - lightning: 2.2.3
    - lightning-utilities: 0.11.2
    - markdown: 3.6
    - markupsafe: 2.1.5
    - mpmath: 1.3.0
    - multidict: 6.0.5
    - networkx: 3.3
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 8.9.2.26
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.1.105
    - packaging: 24.0
    - pip: 23.3.1
    - protobuf: 5.26.1
    - pytorch-lightning: 2.2.3
    - pyyaml: 6.0.1
    - setuptools: 68.2.2
    - six: 1.16.0
    - sympy: 1.12
    - tensorboard: 2.16.2
    - tensorboard-data-server: 0.7.2
    - torch: 2.3.0
    - torchmetrics: 1.3.2
    - tqdm: 4.66.2
    - triton: 2.3.0
    - typing-extensions: 4.11.0
    - werkzeug: 3.0.2
    - wheel: 0.41.2
    - yarl: 1.9.4
- System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.11.9
    - release: 5.15.0-69-generic
    - version: #76~20.04.1-Ubuntu SMP Mon Mar 20 15:54:19 UTC 2023
More info
I am training a Vision-Language model with a CLIP loss. The batch size I want to use is large, which requires computing the embeddings in mini-batches and then computing the gradients in mini-batches, as done in the repo https://github.com/Zasder3/train-CLIP/tree/main (see lines: https://github.com/Zasder3/train-CLIP/blob/79d4c7960072047a9e0d39335ab60dcb150640c3/models/wrapper.py#L64-L109 ).
The issue arose when I implemented a similar algorithm for my use case and tried to train it. I have tried to isolate the problem as much as I could and produced a simple script that reproduces the same error.
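For illustration, here is a rough sketch of that mini-batch gradient pattern in plain PyTorch (simplified; image_enc, text_enc, and clip_like_loss are hypothetical stand-ins, not the actual code from that repo):

import torch
from torch import nn
import torch.nn.functional as F

# Hypothetical stand-ins for the real image/text encoders
image_enc = nn.Linear(512, 16)
text_enc = nn.Linear(512, 16)

def clip_like_loss(img_emb, txt_emb):
    # simplified symmetric cross-entropy over the similarity matrix
    logits = img_emb @ txt_emb.t()
    labels = torch.arange(len(logits))
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2

batch_size, mb_size = 16, 4
images, texts = torch.randn(batch_size, 512), torch.randn(batch_size, 512)
img_chunks = torch.chunk(images, batch_size // mb_size)
txt_chunks = torch.chunk(texts, batch_size // mb_size)

# 1) embed every mini-batch without grad to build the full-batch embeddings
with torch.no_grad():
    img_emb = torch.cat([image_enc(c) for c in img_chunks])
    txt_emb = torch.cat([text_enc(c) for c in txt_chunks])

# 2) re-run one mini-batch at a time with grad enabled, splice its fresh
#    embeddings into the cached full-batch embeddings, and backprop per chunk
for i, (ic, tc) in enumerate(zip(img_chunks, txt_chunks)):
    lo, hi = i * mb_size, (i + 1) * mb_size
    img_part = torch.cat([img_emb[:lo], image_enc(ic), img_emb[hi:]])
    txt_part = torch.cat([txt_emb[:lo], text_enc(tc), txt_emb[hi:]])
    clip_like_loss(img_part, txt_part).backward()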
cc @carmocca @justusschock @awaelchli
I removed excess code, created a new Conda environment installing just pytorch-lightning and tensorboard, and was able to replicate the same issue with lightning version 2.2.3. I have edited the issue above to reflect this.
@pranavrao-qure Here is the same code in plain PyTorch (no Lightning), to show that it results in the same error:
import math
import torch
from torch import nn, GradScaler
from torch.utils.data import TensorDataset, DataLoader


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.query(input)


def get_mini_batches(input):
    num_mb = math.ceil(16 / 4)
    return torch.chunk(input, num_mb)


if __name__ == '__main__':
    test_data_loader = DataLoader(TensorDataset(torch.randn(512, 512)), batch_size=16, shuffle=False)
    test_module_obj = TestModule()
    scaler = GradScaler(device="cpu")
    with torch.autocast(device_type='cpu', dtype=torch.float16):
        batch = next(iter(test_data_loader))
        input = batch[0]
        mini_batches = get_mini_batches(input)
        mb_model_output_list = list()
        with torch.no_grad():
            for mb in mini_batches:
                mb_model_output_list.append(test_module_obj(mb).mean().detach())
        all_loss = sum(mb_model_output_list)
        test_module_obj.train()
        test_module_obj.requires_grad_(True)
        torch.set_grad_enabled(True)
        assert torch.is_grad_enabled()
        assert all(p.requires_grad for p in test_module_obj.parameters())
        for _, mb in enumerate(mini_batches):
            mb_model_output = test_module_obj(mb).mean()
            scaler.scale(mb_model_output).backward()
When you use torch.no_grad(), you also need to disable autocast when you are using mixed precision. Like so in your training step:

with torch.no_grad(), torch.autocast(device_type=self.device.type, enabled=False):
    ...
This seems to be a quirk with PyTorch and how these context managers interact. There is nothing that could be done on the Lightning side to my knowledge.
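As a minimal illustration, here is a sketch of that suggestion applied to the standalone repro above (a sketch only, under the same assumptions and imports as that snippet, e.g. from torch import GradScaler; not a tested patch). The only change is the extra torch.autocast(..., enabled=False) around the no_grad pass:

import math
import torch
from torch import nn, GradScaler

# Same toy setup as the snippet above: one linear layer stands in for the model
module = nn.Linear(512, 16)
scaler = GradScaler(device="cpu")
input = torch.randn(16, 512)
mini_batches = torch.chunk(input, math.ceil(16 / 4))

with torch.autocast(device_type='cpu', dtype=torch.float16):
    # no_grad pass with autocast locally disabled, so no FP16 weight copies
    # get created and cached while grad is off
    with torch.no_grad(), torch.autocast(device_type='cpu', enabled=False):
        all_loss = sum(module(mb).mean().detach() for mb in mini_batches)
    # grad-enabled pass, still under the outer autocast: fresh FP16 copies are
    # created with grad tracking, so backward succeeds
    for mb in mini_batches:
        scaler.scale(module(mb).mean()).backward()

print(all_loss, module.weight.grad is not None)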
It also works if you do this:
import math
import torch
from torch import nn, GradScaler
from torch.utils.data import TensorDataset, DataLoader


class TestModule(nn.Module):
    def __init__(self, in_dim=512, out_dim=16):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.query = nn.Linear(self.in_dim, self.out_dim, bias=True)

    def forward(self, input):
        return self.query(input)


def get_mini_batches(input):
    num_mb = math.ceil(16 / 4)
    return torch.chunk(input, num_mb)


if __name__ == '__main__':
    test_data_loader = DataLoader(TensorDataset(torch.randn(512, 512)), batch_size=16, shuffle=False)
    test_module_obj = TestModule()
    scaler = GradScaler(device="cpu")
    with torch.autocast(device_type='cpu', dtype=torch.float16):
        batch = next(iter(test_data_loader))
        input = batch[0]
        mini_batches = get_mini_batches(input)
        mb_model_output_list = list()
        with torch.no_grad():
            for mb in mini_batches:
                mb_model_output_list.append(test_module_obj(mb).mean().detach())
        all_loss = sum(mb_model_output_list)
        test_module_obj.train()
        test_module_obj.requires_grad_(True)
        torch.set_grad_enabled(True)
        assert torch.is_grad_enabled()
        assert all(p.requires_grad for p in test_module_obj.parameters())
    with torch.autocast(device_type='cpu', dtype=torch.float16):
        for _, mb in enumerate(mini_batches):
            mb_model_output = test_module_obj(mb).mean()
            scaler.scale(mb_model_output).backward()
This works because PyTorch Lightning encloses the entire training step in torch.autocast when you set precision=16 (which by default is mixed precision, i.e. '16-mixed'). The behavior you observe happens because you do both a no_grad forward pass and a grad-enabled forward pass within the same autocast context. In the no_grad forward pass, FP16 copies of the parameters are created and cached. Because it is a no_grad context, these FP16 copies are created with requires_grad=False. When you run net(input) again in a grad-enabled way, you are still within the same autocast context, so the cache is live and the FP16 copies are not recreated (the ops on autocast's FP16 list reuse the cached copies). Since those cached copies have requires_grad=False, net(input) does not build an autograd graph, and the output ends up with requires_grad=False. Exiting and re-entering the autocast context, as in the snippet above, clears that cache, which is why the second forward pass then builds a graph. You can read more here: https://discuss.pytorch.org/t/autocast-and-torch-no-grad-unexpected-behaviour/93475/3
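Here is a tiny hypothetical sketch of that caching behavior (an illustration written for this explanation, not code from the linked thread): whether the second forward builds a graph depends only on whether a no_grad forward already populated the autocast cast cache in the same context.

import torch
from torch import nn

net = nn.Linear(8, 4)   # stand-in model
x = torch.randn(2, 8)

with torch.autocast(device_type='cpu', dtype=torch.float16):
    with torch.no_grad():
        net(x)                 # caches FP16 weight copies with requires_grad=False
    y = net(x)                 # reuses the cached copies -> no autograd graph
    print(y.requires_grad)     # False: y.backward() would raise the same RuntimeError

with torch.autocast(device_type='cpu', dtype=torch.float16):
    z = net(x)                 # fresh context, fresh FP16 copies made under grad mode
    print(z.requires_grad)     # True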