pytorch-lightning
pytorch-lightning copied to clipboard
trainer test and validate have issues with autograd
Bug description
Hi,
It seems that trainer.test and trainer.validate have some issues with manual grading while trainer.fit not!
In my code trainer.fit(train_episode_loader, val_task_loader) works but none of trainer.test(m, val_task_loader) and trainer.validate(m, val_task_loader) does.
I tried both higher library and functional. I keep getting the same error.
What version are you seeing the problem on?
v2.4
How to reproduce the bug
import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from tqdm.auto import tqdm
import torch.nn.functional as F
from torch.func import grad, functional_call
try:
import higher
except ModuleNotFoundError:
!{sys.executable} -m pip install --quiet higher
import higher
class MAMLModel(pl.LightningModule):
def __init__(self, inner_lr, meta_lr, inner_steps, n_way, inner_reg=0.0005, meta_reg=0.0005):
super(MAMLModel, self).__init__()
self.encoder = autoencoder.encoder
self.latent_dim = autoencoder.latent_dim
self.classifier = nn.Linear(self.latent_dim, n_way)
self.validation_step_outputs = []
self.test_step_outputs = []
self.automatic_optimization=False
self.save_hyperparameters()
def forward(self, x):
emb = self.encoder(x)
out = self.classifier(emb)
return out
@staticmethod
def get_targets_from_labels(labels):
classes = torch.unique(labels)
targets = (classes[None, :] == labels[:, None]).long().argmax(dim=-1)
return targets
def step_higher(self, support_images, support_targets, query_images, query_targets):
inner_opt = torch.optim.SGD(self.parameters(), lr=self.hparams.inner_lr)
with higher.innerloop_ctx(self, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
# Inner loop
for _ in range(self.hparams.inner_steps):
support_outputs = fmodel(support_images)
support_loss = F.cross_entropy(support_outputs, support_targets)
support_loss += self.hparams.inner_reg * sum(p.pow(2).mean() for n, p in fmodel.named_parameters() if p.requires_grad and 'weight' in n)
self.log('val_support_loss', support_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
diffopt.step(loss)
# Compute query loss
query_outputs = fmodel(query_images)
return query_outputs
def step_functional(self, support_images, support_targets, query_images, query_targets):
# Copy the initial parameters
task_params = {name: param.clone() for name, param in self.named_parameters()}
# Inner loop
for _ in range(self.hparams.inner_steps):
support_outputs = functional_call(self, task_params, support_images)
support_loss = F.cross_entropy(support_outputs, support_targets)
support_loss += self.hparams.inner_reg * sum(p.pow(2).mean() for n, p in task_params.items() if p.requires_grad and 'weight' in n)
self.log('val_support_loss', support_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
# Compute gradients
grads = torch.autograd.grad(support_loss, task_params.values(), create_graph=True)
# grads = torch.autograd.grad(support_loss, [p for p in task_params.values() if p.requires_grad], create_graph=True)
# clamp grads
grads = [torch.clamp(grad, min=-5.0, max=5.0) for grad in grads]
# Update parameters
task_params = {name: param - self.hparams.inner_lr * grad
for (name, param), grad in zip(task_params.items(), grads)}
# idx = 0
# for name, param in task_params.items():
# if param.requires_grad:
# task_params[name] = param - (self.hparams.inner_lr * grads[idx])
# idx += 1
# Compute query loss
query_outputs = functional_call(self, task_params, query_images)
return query_outputs
def step(self, *args):
return self.step_higher(*args)
def training_step(self, batch, batch_idx):
meta_loss = 0.
for (support_images, support_labels), (query_images, query_labels) in batch:
support_images, support_targets = support_images.to(self.device), MAMLModel.get_targets_from_labels(support_labels).to(self.device)
query_images, query_targets = query_images.to(self.device), MAMLModel.get_targets_from_labels(query_labels).to(self.device)
query_outputs = self.step(support_images, support_targets, query_images, query_targets)
query_loss = F.cross_entropy(query_outputs, query_targets)
meta_loss += query_loss
opt = self.optimizers()
opt.zero_grad()
self.manual_backward(meta_loss)
opt.step()
self.log('train_query_loss', meta_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
def validation_step(self, batch, batch_idx):
self.train()
with torch.set_grad_enabled(True):
(support_images, support_labels), (query_images, query_labels) = batch
support_images, support_targets = support_images.to(self.device), MAMLModel.get_targets_from_labels(support_labels).to(self.device)
query_images, query_targets = query_images.to(self.device), MAMLModel.get_targets_from_labels(query_labels).to(self.device)
query_outputs = self.step(support_images, support_targets, query_images, query_targets)
self.validation_step_outputs.append({'preds': query_outputs, 'targets': query_targets})
def on_validation_epoch_end(self):
preds = torch.concat([x['preds'] for x in self.validation_step_outputs], dim=0)
targets = torch.concat([x['targets'] for x in self.validation_step_outputs], dim=0)
loss = F.cross_entropy(preds, targets)
acc = (preds.argmax(dim=-1) == targets).float().mean()
self.log('val_acc', acc, on_epoch=True, prog_bar=True, logger=True)
self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
self.validation_step_outputs.clear() # free memory
def test_step(self, batch, batch_idx):
self.train()
with torch.set_grad_enabled(True):
(support_images, support_labels), (query_images, query_labels) = batch
support_images, support_targets = support_images.to(self.device), MAMLModel.get_targets_from_labels(support_labels).to(self.device)
query_images, query_targets = query_images.to(self.device), MAMLModel.get_targets_from_labels(query_labels).to(self.device)
query_outputs = self.step(support_images, support_targets, query_images, query_targets)
self.test_step_outputs.append({'preds': query_outputs, 'targets': query_targets})
def on_test_epoch_end(self):
preds = torch.concat([x['preds'] for x in self.validation_step_outputs], dim=0)
targets = torch.concat([x['targets'] for x in self.validation_step_outputs], dim=0)
loss = F.cross_entropy(preds, targets)
acc = (preds.argmax(dim=-1) == targets).float().mean()
self.log('test_acc', acc, on_epoch=True, prog_bar=True, logger=True)
self.log('test_loss', loss, on_epoch=True, prog_bar=True, logger=True)
self.test_step_outputs.clear() # free memory
def configure_optimizers(self):
return optim.Adam(self.parameters(), lr=self.hparams.meta_lr)
# Hyperparameters
inner_lr = 0.05 # 0.05
meta_lr = 0.0005 # 0.0005
inner_steps = 10
num_epochs = 10
m = MAMLModel(inner_lr, meta_lr, inner_steps, n_way, inner_reg=0.0005, meta_reg=0.0005)
# Model checkpoint callback
checkpoint_callback = ModelCheckpoint(
monitor='val_query_loss',
dirpath='monitor/',
filename='best_maml',
save_top_k=3,
mode='min'
)
# Initialize Trainer
trainer = Trainer(
max_epochs=num_epochs,
callbacks=[checkpoint_callback],
accelerator='gpu' if torch.cuda.is_available() else 'cpu',
)
# Train the model
trainer.fit(m, train_episode_loader, val_task_loader)
# Test the model
trainer.test(m, val_task_loader)
Error messages and logs
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[156], line 171
164 trainer = Trainer(
165 max_epochs=num_epochs,
166 callbacks=[checkpoint_callback],
167 accelerator='gpu' if torch.cuda.is_available() else 'cpu',
168 )
170 # Train the model
--> 171 trainer.fit(m, train_episode_loader, val_task_loader)
173 # Test the model
174 trainer.test(m, val_task_loader)
File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:538](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=537), in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
536 self.state.status = TrainerStatus.RUNNING
537 self.training = True
--> 538 call._call_and_handle_interrupt(
539 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
540 )
File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\call.py:47](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/call.py#line=46), in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
45 if trainer.strategy.launcher is not None:
46 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47 return trainer_fn(*args, **kwargs)
49 except _TunerExitException:
50 _call_teardown_hook(trainer)
File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:574](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=573), in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
567 assert self.state.fn is not None
568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
569 self.state.fn,
570 ckpt_path,
571 model_provided=True,
572 model_connected=self.lightning_module is not None,
573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
576 assert self.state.stopped
577 self.training = False
File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:981](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=980), in Trainer._run(self, model, ckpt_path)
976 self._signal_connector.register_signal_handlers()
978 # ----------------------------
979 # RUN THE TRAINER
980 # ----------------------------
--> 981 results = self._run_stage()
983 # ----------------------------
984 # POST-Training CLEAN UP
985 # ----------------------------
986 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1023](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=1022), in Trainer._run_stage(self)
1021 if self.training:
1022 with isolate_rng():
-> 1023 self._run_sanity_check()
1024 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
1025 self.fit_loop.run()
File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1052](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=1051), in Trainer._run_sanity_check(self)
1049 call._call_callback_hooks(self, "on_sanity_check_start")
1051 # run eval step
-> 1052 val_loop.run()
1054 call._call_callback_hooks(self, "on_sanity_check_end")
1056 # reset logger connector
File [~\pythonEnv\lib\site-packages\pytorch_lightning\loops\utilities.py:178](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/loops/utilities.py#line=177), in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
176 context_manager = torch.no_grad
177 with context_manager():
--> 178 return loop_run(self, *args, **kwargs)
File [~\pythonEnv\lib\site-packages\pytorch_lightning\loops\evaluation_loop.py:135](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/loops/evaluation_loop.py#line=134), in _EvaluationLoop.run(self)
133 self.batch_progress.is_last_batch = data_fetcher.done
134 # run step hooks
--> 135 self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
136 except StopIteration:
137 # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
138 break
File [~\pythonEnv\lib\site-packages\pytorch_lightning\loops\evaluation_loop.py:396](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/loops/evaluation_loop.py#line=395), in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
390 hook_name = "test_step" if trainer.testing else "validation_step"
391 step_args = (
392 self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
393 if not using_dataloader_iter
394 else (dataloader_iter,)
395 )
--> 396 output = call._call_strategy_hook(trainer, hook_name, *step_args)
398 self.batch_progress.increment_processed()
400 if using_dataloader_iter:
401 # update the hook kwargs now that the step method might have consumed the iterator
File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\call.py:319](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/call.py#line=318), in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
316 return None
318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319 output = fn(*args, **kwargs)
321 # restore current_fx when nested context
322 pl_module._current_fx_name = prev_fx_name
File [~\pythonEnv\lib\site-packages\pytorch_lightning\strategies\strategy.py:411](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/strategies/strategy.py#line=410), in Strategy.validation_step(self, *args, **kwargs)
409 if self.model != self.lightning_module:
410 return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 411 return self.lightning_module.validation_step(*args, **kwargs)
Cell In[156], line 111, in MAMLModel.validation_step(self, batch, batch_idx)
109 support_images, support_targets = support_images.to(self.device), MAMLModel.get_targets_from_labels(support_labels).to(self.device)
110 query_images, query_targets = query_images.to(self.device), MAMLModel.get_targets_from_labels(query_labels).to(self.device)
--> 111 query_outputs = self.step(support_images, support_targets, query_images, query_targets)
112 self.validation_step_outputs.append({'preds': query_outputs, 'targets': query_targets})
Cell In[156], line 87, in MAMLModel.step(self, *args)
86 def step(self, *args):
---> 87 return self.step_higher(*args)
Cell In[156], line 46, in MAMLModel.step_higher(self, support_images, support_targets, query_images, query_targets)
44 support_loss += self.hparams.inner_reg * sum(p.pow(2).mean() for n, p in fmodel.named_parameters() if p.requires_grad and 'weight' in n)
45 self.log('val_support_loss', support_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
---> 46 diffopt.step(loss)
49 # Compute query loss
50 query_outputs = fmodel(query_images)
File [~\pythonEnv\lib\site-packages\higher\optim.py:229](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/higher/optim.py#line=228), in DifferentiableOptimizer.step(self, loss, params, override, grad_callback, **kwargs)
223 # This allows us to gracefully deal with cases where params are frozen.
224 grad_targets = [
225 p if p.requires_grad else _torch.tensor([], requires_grad=True)
226 for p in params
227 ]
--> 229 all_grads = _torch.autograd.grad(
230 loss,
231 grad_targets,
232 create_graph=self._track_higher_grads,
233 allow_unused=True # boo
234 )
236 if grad_callback is not None:
237 all_grads = grad_callback(all_grads)
File [~\pythonEnv\lib\site-packages\torch\autograd\__init__.py:394](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/torch/autograd/__init__.py#line=393), in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)
390 result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
391 grad_outputs_
392 )
393 else:
--> 394 result = Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
395 t_outputs,
396 grad_outputs_,
397 retain_graph,
398 create_graph,
399 t_inputs,
400 allow_unused,
401 accumulate_grad=False,
402 ) # Calls into the C++ engine to run the backward pass
403 if materialize_grads:
404 result = tuple(
405 output
406 if output is not None
407 else torch.zeros_like(input, requires_grad=True)
408 for (output, input) in zip(result, t_inputs)
409 )
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Environment
Current environment
#- PyTorch Lightning Version: 2.4.0
#- PyTorch Version: 2.1.1+cpu
More info
No response