icevision

Wandb + fastai + RCNN error: patched solution

Open FraPochetti opened this issue 5 years ago • 0 comments

While training an RCNN learner with fastai and the WandbCallback, the following exception is raised:

TypeError                                 Traceback (most recent call last)
~/miniconda3/envs/wandb/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
--> 156         try:       self(f'before_{event_type}')       ;f()
    157         except ex: self(f'after_cancel_{event_type}')

~/miniconda3/envs/wandb/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
--> 167         self.pred = self.model(*self.xb)
    168 

~/miniconda3/envs/wandb/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(

~/miniconda3/envs/wandb/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     78 
---> 79         images, targets = self.transform(images, targets)
     80 

~/miniconda3/envs/wandb/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    734                 else:
--> 735                     var = var[0]
    736             grad_fn = var.grad_fn

TypeError: 'ImageList' object is not subscriptable

The issue is triggered by the fastai WandbCallback, which invokes wandb.watch(self.learn.model, log=self.log). This wandb function registers hooks on the model being trained so that gradients and/or parameters (depending on the log argument) can be logged during the forward and backward passes. These hooks interfere with the first building block of torchvision's RCNN implementation, the GeneralizedRCNNTransform

  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )

which preprocesses the batch before it reaches the model's backbone and returns an ImageList object (together with the targets). During the model's forward pass, the backward hooks registered by wandb cause the following snippet from torch/nn/modules/module.py to run:

        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]

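which throws the exception. GeneralizedRCNNTransform returns a tuple (ImageList, targets): the loop takes its first element, an ImageList, which is neither a torch.Tensor nor a dict, and then tries to index it with var[0]. ImageList does not implement __getitem__, hence "TypeError: 'ImageList' object is not subscriptable".

The temporary solution we implemented (here) monkey-patches the wandb.watch function, replacing it with a no-op, ONLY when an RCNN learner is being trained. This does not disrupt the rest of the W&B logging performed by WandbCallback (losses and metrics are still logged; only the hooks installed by wandb.watch are skipped), and it keeps the training process from failing.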
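The failure can be reproduced without torch or torchvision. The sketch below copies the unwrapping loop quoted above into a plain function; FakeTensor and FakeImageList are hypothetical stand-ins for torch.Tensor and torchvision's ImageList (like the real ImageList, the stand-in defines no __getitem__). It illustrates the mechanism only and is not torch's actual code path.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor."""


class FakeImageList:
    """Hypothetical stand-in for torchvision's ImageList: it holds the batched
    tensors and image sizes but defines no __getitem__, so it cannot be indexed."""

    def __init__(self, tensors, image_sizes):
        self.tensors = tensors
        self.image_sizes = image_sizes


def first_tensor(result):
    """Mirror of the unwrapping loop from torch/nn/modules/module.py,
    with FakeTensor in place of torch.Tensor."""
    var = result
    while not isinstance(var, FakeTensor):
        if isinstance(var, dict):
            var = next(v for v in var.values() if isinstance(v, FakeTensor))
        else:
            var = var[0]  # assumes every non-dict container is indexable
    return var


# An ordinary module output, e.g. a tuple starting with a tensor, unwraps fine.
t = FakeTensor()
assert first_tensor((t, "extra")) is t

# GeneralizedRCNNTransform returns (ImageList, targets): the loop takes
# element 0, gets the ImageList, and then tries to index it as well.
try:
    first_tensor((FakeImageList(FakeTensor(), [(800, 1066)]), None))
except TypeError as err:
    print(err)  # 'FakeImageList' object is not subscriptable
```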
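A minimal sketch of such a patch, assuming fastai's WandbCallback looks up wandb.watch on the wandb module at call time (it calls wandb.watch(self.learn.model, log=self.log), as described above). wandb_watch_disabled and is_rcnn are hypothetical helpers, not necessarily the code behind the (here) link, and is_rcnn assumes the learner's model is a torchvision GeneralizedRCNN subclass such as FasterRCNN or MaskRCNN.

```python
import contextlib


@contextlib.contextmanager
def wandb_watch_disabled(wandb_module, disable=True):
    """Hypothetical helper: temporarily replace wandb_module.watch with a no-op.

    Swapping the attribute on the module object means any caller that does
    `wandb.watch(...)` inside the block gets the no-op instead of the real
    function, so no hooks are registered on the model.
    """
    original_watch = wandb_module.watch
    if disable:
        # No-op accepting any arguments; returns None.
        wandb_module.watch = lambda *args, **kwargs: None
    try:
        yield
    finally:
        # Always restore the real function, even if training raises.
        wandb_module.watch = original_watch


def is_rcnn(model):
    """Hypothetical check: True if model is a torchvision R-CNN
    (FasterRCNN, MaskRCNN, KeypointRCNN all subclass GeneralizedRCNN)."""
    # Imported lazily so torchvision is only required when this is called.
    from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN

    return isinstance(model, GeneralizedRCNN)


# Usage (requires fastai, wandb and torchvision; not executed here):
#   import wandb
#   from fastai.callback.wandb import WandbCallback
#   with wandb_watch_disabled(wandb, disable=is_rcnn(learn.model)):
#       learn.fit_one_cycle(10, cbs=[WandbCallback()])
```

Scoping the patch to a context manager, rather than overwriting wandb.watch for the whole session, means learners that are not RCNNs still get gradient and parameter logging once the block exits.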

FraPochetti, Nov 03 '20 18:11