pytorch-lightning icon indicating copy to clipboard operation
pytorch-lightning copied to clipboard

trainer test and validate have issues with autograd

Open bpfrd opened this issue 1 year ago • 0 comments

Bug description

Hi, It seems that trainer.test and trainer.validate have some issues with manual grading while trainer.fit not! In my code trainer.fit(train_episode_loader, val_task_loader) works but none of trainer.test(m, val_task_loader) and trainer.validate(m, val_task_loader) does. I tried both higher library and functional. I keep getting the same error.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

import torch
import torch.nn as nn
import torch.optim as optim
from copy import deepcopy
from tqdm.auto import tqdm
import torch.nn.functional as F
from torch.func import grad, functional_call
try:
    import higher
except ModuleNotFoundError:
    !{sys.executable} -m pip install --quiet higher
import higher

class MAMLModel(pl.LightningModule):
    def __init__(self, inner_lr, meta_lr, inner_steps, n_way, inner_reg=0.0005, meta_reg=0.0005):
        super(MAMLModel, self).__init__()
        
        self.encoder = autoencoder.encoder
        self.latent_dim = autoencoder.latent_dim
        self.classifier = nn.Linear(self.latent_dim, n_way)
        self.validation_step_outputs = []
        self.test_step_outputs = []
        self.automatic_optimization=False
        self.save_hyperparameters()

    def forward(self, x):
        emb = self.encoder(x)
        out = self.classifier(emb)
        return out

    @staticmethod
    def get_targets_from_labels(labels):
        classes = torch.unique(labels)
        targets = (classes[None, :] == labels[:, None]).long().argmax(dim=-1)
        return targets

    def step_higher(self, support_images, support_targets, query_images, query_targets):
        inner_opt = torch.optim.SGD(self.parameters(), lr=self.hparams.inner_lr)
        with higher.innerloop_ctx(self, inner_opt, copy_initial_weights=False) as (fmodel, diffopt):
            # Inner loop
            for _ in range(self.hparams.inner_steps):
                support_outputs = fmodel(support_images)
                support_loss = F.cross_entropy(support_outputs, support_targets)
                support_loss += self.hparams.inner_reg * sum(p.pow(2).mean() for n, p in fmodel.named_parameters() if p.requires_grad and 'weight' in n)
                self.log('val_support_loss', support_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
                diffopt.step(loss)
                

            # Compute query loss
            query_outputs = fmodel(query_images)
            
        return query_outputs
        
    def step_functional(self, support_images, support_targets, query_images, query_targets):
        
        # Copy the initial parameters
        task_params = {name: param.clone() for name, param in self.named_parameters()}

        # Inner loop
        for _ in range(self.hparams.inner_steps):
            support_outputs = functional_call(self, task_params, support_images)
            support_loss = F.cross_entropy(support_outputs, support_targets)
            support_loss += self.hparams.inner_reg * sum(p.pow(2).mean() for n, p in task_params.items() if p.requires_grad and 'weight' in n)
            self.log('val_support_loss', support_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
            
            # Compute gradients
            grads = torch.autograd.grad(support_loss, task_params.values(), create_graph=True)
            # grads = torch.autograd.grad(support_loss, [p for p in task_params.values() if p.requires_grad], create_graph=True)

            # clamp grads
            grads = [torch.clamp(grad, min=-5.0, max=5.0) for grad in grads]
            
            # Update parameters
            task_params = {name: param - self.hparams.inner_lr * grad
                          for (name, param), grad in zip(task_params.items(), grads)}
            # idx = 0
            # for name, param in task_params.items():
            #     if param.requires_grad:
            #         task_params[name] = param - (self.hparams.inner_lr * grads[idx])
            #         idx += 1
    
        # Compute query loss
        query_outputs = functional_call(self, task_params, query_images)
        return query_outputs

    def step(self, *args):
        return self.step_higher(*args)
                    
    def training_step(self, batch, batch_idx):
        meta_loss = 0.
        for (support_images, support_labels), (query_images, query_labels) in batch:
            support_images, support_targets = support_images.to(self.device), MAMLModel.get_targets_from_labels(support_labels).to(self.device)
            query_images, query_targets = query_images.to(self.device), MAMLModel.get_targets_from_labels(query_labels).to(self.device)
            query_outputs = self.step(support_images, support_targets, query_images, query_targets)
            query_loss = F.cross_entropy(query_outputs, query_targets)
            meta_loss += query_loss
        
        opt = self.optimizers()
        opt.zero_grad()
        self.manual_backward(meta_loss)
        opt.step()
        self.log('train_query_loss', meta_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

    def validation_step(self, batch, batch_idx):
        self.train()
        with torch.set_grad_enabled(True):
        
         (support_images, support_labels), (query_images, query_labels) = batch
         support_images, support_targets = support_images.to(self.device), MAMLModel.get_targets_from_labels(support_labels).to(self.device)
         query_images, query_targets = query_images.to(self.device), MAMLModel.get_targets_from_labels(query_labels).to(self.device)
         query_outputs = self.step(support_images, support_targets, query_images, query_targets)
         self.validation_step_outputs.append({'preds': query_outputs, 'targets': query_targets})

    def on_validation_epoch_end(self):
        preds = torch.concat([x['preds'] for x in self.validation_step_outputs], dim=0)
        targets = torch.concat([x['targets'] for x in self.validation_step_outputs], dim=0)
        loss = F.cross_entropy(preds, targets)
        acc = (preds.argmax(dim=-1) == targets).float().mean()
        self.log('val_acc', acc, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.validation_step_outputs.clear()  # free memory

    def test_step(self, batch, batch_idx):
        self.train()
        with torch.set_grad_enabled(True):
        
          (support_images, support_labels), (query_images, query_labels) = batch
          support_images, support_targets = support_images.to(self.device), MAMLModel.get_targets_from_labels(support_labels).to(self.device)
          query_images, query_targets = query_images.to(self.device), MAMLModel.get_targets_from_labels(query_labels).to(self.device)
          query_outputs = self.step(support_images, support_targets, query_images, query_targets)
          self.test_step_outputs.append({'preds': query_outputs, 'targets': query_targets})

    def on_test_epoch_end(self):
        preds = torch.concat([x['preds'] for x in self.validation_step_outputs], dim=0)
        targets = torch.concat([x['targets'] for x in self.validation_step_outputs], dim=0)
        loss = F.cross_entropy(preds, targets)
        acc = (preds.argmax(dim=-1) == targets).float().mean()
        self.log('test_acc', acc, on_epoch=True, prog_bar=True, logger=True)
        self.log('test_loss', loss, on_epoch=True, prog_bar=True, logger=True)
        self.test_step_outputs.clear()  # free memory
        
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.hparams.meta_lr)


# Hyperparameters
inner_lr = 0.05 # 0.05
meta_lr = 0.0005 # 0.0005
inner_steps = 10
num_epochs = 10

m = MAMLModel(inner_lr, meta_lr, inner_steps, n_way, inner_reg=0.0005, meta_reg=0.0005)

# Model checkpoint callback
checkpoint_callback = ModelCheckpoint(
    monitor='val_query_loss',
    dirpath='monitor/',
    filename='best_maml',
    save_top_k=3,
    mode='min'
)

# Initialize Trainer
trainer = Trainer(
    max_epochs=num_epochs,
    callbacks=[checkpoint_callback],
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
)

# Train the model
trainer.fit(m, train_episode_loader, val_task_loader)

# Test the model
trainer.test(m, val_task_loader)

Error messages and logs

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[156], line 171
    164 trainer = Trainer(
    165     max_epochs=num_epochs,
    166     callbacks=[checkpoint_callback],
    167     accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    168 )
    170 # Train the model
--> 171 trainer.fit(m, train_episode_loader, val_task_loader)
    173 # Test the model
    174 trainer.test(m, val_task_loader)

File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:538](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=537), in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    536 self.state.status = TrainerStatus.RUNNING
    537 self.training = True
--> 538 call._call_and_handle_interrupt(
    539     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    540 )

File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\call.py:47](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/call.py#line=46), in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     45     if trainer.strategy.launcher is not None:
     46         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 47     return trainer_fn(*args, **kwargs)
     49 except _TunerExitException:
     50     _call_teardown_hook(trainer)

File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:574](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=573), in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    567 assert self.state.fn is not None
    568 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    569     self.state.fn,
    570     ckpt_path,
    571     model_provided=True,
    572     model_connected=self.lightning_module is not None,
    573 )
--> 574 self._run(model, ckpt_path=ckpt_path)
    576 assert self.state.stopped
    577 self.training = False

File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:981](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=980), in Trainer._run(self, model, ckpt_path)
    976 self._signal_connector.register_signal_handlers()
    978 # ----------------------------
    979 # RUN THE TRAINER
    980 # ----------------------------
--> 981 results = self._run_stage()
    983 # ----------------------------
    984 # POST-Training CLEAN UP
    985 # ----------------------------
    986 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1023](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=1022), in Trainer._run_stage(self)
   1021 if self.training:
   1022     with isolate_rng():
-> 1023         self._run_sanity_check()
   1024     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
   1025         self.fit_loop.run()

File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\trainer.py:1052](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/trainer.py#line=1051), in Trainer._run_sanity_check(self)
   1049 call._call_callback_hooks(self, "on_sanity_check_start")
   1051 # run eval step
-> 1052 val_loop.run()
   1054 call._call_callback_hooks(self, "on_sanity_check_end")
   1056 # reset logger connector

File [~\pythonEnv\lib\site-packages\pytorch_lightning\loops\utilities.py:178](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/loops/utilities.py#line=177), in _no_grad_context.<locals>._decorator(self, *args, **kwargs)
    176     context_manager = torch.no_grad
    177 with context_manager():
--> 178     return loop_run(self, *args, **kwargs)

File [~\pythonEnv\lib\site-packages\pytorch_lightning\loops\evaluation_loop.py:135](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/loops/evaluation_loop.py#line=134), in _EvaluationLoop.run(self)
    133     self.batch_progress.is_last_batch = data_fetcher.done
    134     # run step hooks
--> 135     self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
    136 except StopIteration:
    137     # this needs to wrap the `*_step` call too (not just `next`) for `dataloader_iter` support
    138     break

File [~\pythonEnv\lib\site-packages\pytorch_lightning\loops\evaluation_loop.py:396](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/loops/evaluation_loop.py#line=395), in _EvaluationLoop._evaluation_step(self, batch, batch_idx, dataloader_idx, dataloader_iter)
    390 hook_name = "test_step" if trainer.testing else "validation_step"
    391 step_args = (
    392     self._build_step_args_from_hook_kwargs(hook_kwargs, hook_name)
    393     if not using_dataloader_iter
    394     else (dataloader_iter,)
    395 )
--> 396 output = call._call_strategy_hook(trainer, hook_name, *step_args)
    398 self.batch_progress.increment_processed()
    400 if using_dataloader_iter:
    401     # update the hook kwargs now that the step method might have consumed the iterator

File [~\pythonEnv\lib\site-packages\pytorch_lightning\trainer\call.py:319](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/trainer/call.py#line=318), in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    316     return None
    318 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 319     output = fn(*args, **kwargs)
    321 # restore current_fx when nested context
    322 pl_module._current_fx_name = prev_fx_name

File [~\pythonEnv\lib\site-packages\pytorch_lightning\strategies\strategy.py:411](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/pytorch_lightning/strategies/strategy.py#line=410), in Strategy.validation_step(self, *args, **kwargs)
    409 if self.model != self.lightning_module:
    410     return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
--> 411 return self.lightning_module.validation_step(*args, **kwargs)

Cell In[156], line 111, in MAMLModel.validation_step(self, batch, batch_idx)
    109 support_images, support_targets = support_images.to(self.device), MAMLModel.get_targets_from_labels(support_labels).to(self.device)
    110 query_images, query_targets = query_images.to(self.device), MAMLModel.get_targets_from_labels(query_labels).to(self.device)
--> 111 query_outputs = self.step(support_images, support_targets, query_images, query_targets)
    112 self.validation_step_outputs.append({'preds': query_outputs, 'targets': query_targets})

Cell In[156], line 87, in MAMLModel.step(self, *args)
     86 def step(self, *args):
---> 87     return self.step_higher(*args)

Cell In[156], line 46, in MAMLModel.step_higher(self, support_images, support_targets, query_images, query_targets)
     44     support_loss += self.hparams.inner_reg * sum(p.pow(2).mean() for n, p in fmodel.named_parameters() if p.requires_grad and 'weight' in n)
     45     self.log('val_support_loss', support_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
---> 46     diffopt.step(loss)
     49 # Compute query loss
     50 query_outputs = fmodel(query_images)

File [~\pythonEnv\lib\site-packages\higher\optim.py:229](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/higher/optim.py#line=228), in DifferentiableOptimizer.step(self, loss, params, override, grad_callback, **kwargs)
    223 # This allows us to gracefully deal with cases where params are frozen.
    224 grad_targets = [
    225     p if p.requires_grad else _torch.tensor([], requires_grad=True)
    226     for p in params
    227 ]
--> 229 all_grads = _torch.autograd.grad(
    230     loss,
    231     grad_targets,
    232     create_graph=self._track_higher_grads,
    233     allow_unused=True  # boo
    234 )
    236 if grad_callback is not None:
    237     all_grads = grad_callback(all_grads)

File [~\pythonEnv\lib\site-packages\torch\autograd\__init__.py:394](http://localhost:8888/lab/tree/src/few-shot-learning/~/pythonEnv/lib/site-packages/torch/autograd/__init__.py#line=393), in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)
    390     result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(
    391         grad_outputs_
    392     )
    393 else:
--> 394     result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    395         t_outputs,
    396         grad_outputs_,
    397         retain_graph,
    398         create_graph,
    399         t_inputs,
    400         allow_unused,
    401         accumulate_grad=False,
    402     )  # Calls into the C++ engine to run the backward pass
    403 if materialize_grads:
    404     result = tuple(
    405         output
    406         if output is not None
    407         else torch.zeros_like(input, requires_grad=True)
    408         for (output, input) in zip(result, t_inputs)
    409     )

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Environment

Current environment
#- PyTorch Lightning Version: 2.4.0
#- PyTorch Version: 2.1.1+cpu

More info

No response

bpfrd avatar Aug 08 '24 15:08 bpfrd