Wandb + fastai + RCNN error: patched solution
While training an RCNN learner with fastai and the WandbCallback, the following exception is raised:
TypeError                                 Traceback (most recent call last)
~/miniconda3/envs/wandb/lib/python3.8/site-packages/fastai/learner.py in _with_events(self, f, event_type, ex, final)
--> 156         try: self(f'before_{event_type}');  f()
    157         except ex: self(f'after_cancel_{event_type}')

~/miniconda3/envs/wandb/lib/python3.8/site-packages/fastai/learner.py in _do_one_batch(self)
--> 167         self.pred = self.model(*self.xb)
    168

~/miniconda3/envs/wandb/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    721             else:
--> 722                 result = self.forward(*input, **kwargs)
    723             for hook in itertools.chain(

~/miniconda3/envs/wandb/lib/python3.8/site-packages/torchvision/models/detection/generalized_rcnn.py in forward(self, images, targets)
     78
---> 79         images, targets = self.transform(images, targets)
     80

~/miniconda3/envs/wandb/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    734                 else:
--> 735                     var = var[0]
    736                 grad_fn = var.grad_fn

TypeError: 'ImageList' object is not subscriptable
The issue is triggered by fastai's WandbCallback, which invokes wandb.watch(self.learn.model, log=self.log).
This wandb function registers hooks on the model being trained in order to log relevant metrics (losses, gradients, etc.) during the forward and backward passes.
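Outside the callback, the equivalent standalone call looks roughly like this (a minimal sketch with a stand-in model; the project name is hypothetical):

import torch.nn as nn
import wandb

model = nn.Linear(4, 2)            # stand-in for the learner's model
wandb.init(project="rcnn-debug")   # hypothetical project name
wandb.watch(model, log="all")      # registers hooks to log gradients and parameters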
This behavior interferes with the first building block of the torchvision implementation of an RCNN model, the GeneralizedRCNNTransform:
(transform): GeneralizedRCNNTransform(
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    Resize(min_size=(800,), max_size=1333, mode='bilinear')
)
which is tasked with preprocessing the batch before passing it to the model's backbone, and which returns an ImageList object rather than a plain Tensor. During the model's forward pass, the presence of the wandb hooks triggers the execution of the following snippet from torch/nn/modules/module.py:
if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
    var = result
    while not isinstance(var, torch.Tensor):
        if isinstance(var, dict):
            var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
        else:
            var = var[0]
which throws the exception: the transform returns an (ImageList, targets) tuple, so the first var = var[0] yields the ImageList, and the next iteration of the loop fails because ImageList is not subscriptable.
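The failure can in fact be reproduced without fastai or wandb at all: any module-level backward hook on the transform puts the forward pass on the same code path. A minimal sketch, assuming a torch/torchvision version matching the traceback above (the dummy hook, image sizes, and targets are ours):

import torch
import torchvision

# fasterrcnn_resnet50_fpn starts with the GeneralizedRCNNTransform shown above.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
model.train()

# A no-op backward hook is enough to make _call_impl walk the transform's
# output with `var = var[0]`. The output is an (ImageList, targets) tuple,
# so the first indexing yields the ImageList and the second one raises.
model.transform.register_backward_hook(lambda module, grad_in, grad_out: None)

images = [torch.rand(3, 300, 400)]
targets = [{"boxes": torch.tensor([[10.0, 10.0, 100.0, 100.0]]),
            "labels": torch.tensor([1])}]
model(images, targets)  # TypeError: 'ImageList' object is not subscriptable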
The temporary solution we implemented (here) consists of monkey-patching the wandb.watch function, replacing it with a no-op, ONLY if an RCNN learner is being trained. This does not disrupt standard W&B logging and prevents the training process from failing.
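A rough sketch of the workaround (the helper below is ours and only illustrates the idea; the actual patch lives in the linked code):

import wandb

_original_watch = wandb.watch

def _noop_watch(*args, **kwargs):
    # No hooks are registered, so the RCNN forward pass is left untouched.
    return None

def fit_rcnn(learn, n_epoch):
    # Disable wandb.watch only for the duration of the RCNN training run.
    wandb.watch = _noop_watch
    try:
        learn.fit(n_epoch)
    finally:
        wandb.watch = _original_watch  # restore normal behavior afterwards

Metrics sent through wandb.log (losses, learning rate, etc.) are unaffected; only the hook-based gradient and parameter logging is skipped for these models.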