torchfunc
Recorders only record activations of one sample, not the entire batch
I pass an entire batch through the model, yet when I check the recorder's data it only shows the activations for one sample (I think the last one).
An abbreviated version of my code:
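```python
import torch
import torch.nn as nn
import torchfunc

# `model` and `testloader` are defined elsewhere (code abbreviated)
recorder = torchfunc.hooks.recorders.ForwardOutput()
# Record the output of every Conv2d layer in the model
recorder.modules(model, types=(nn.Conv2d,))

with torch.no_grad():
    model.eval()
    for batch_idx, (data, target) in enumerate(testloader):
        print('data:', data.shape)
        output = model(data)
        break  # a single batch is enough to show the problem

activations = recorder.data
print('\nactivations:', len(activations))
print('subrecorders:', [subrecorder[0].shape for subrecorder in activations])
print('subrecorder {} samples: {}'.format(0, recorder.samples(0)))
```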
Output:
```
data: torch.Size([32, 3, 64, 64])

activations: 10
subrecorders: [torch.Size([64, 64, 64]), torch.Size([64, 64, 64]), torch.Size([128, 32, 32]), torch.Size([128, 32, 32]), torch.Size([256, 16, 16]), torch.Size([256, 16, 16]), torch.Size([256, 16, 16]), torch.Size([1024, 16, 16]), torch.Size([1024, 16, 16]), torch.Size([21, 16, 16])]
subrecorder 0 samples: 1
```
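Until the recorder is fixed, plain PyTorch forward hooks can capture the full batch. The sketch below is not part of torchfunc: `record_conv_outputs` is a hypothetical helper, and the two-layer `nn.Sequential` is a stand-in for the real network, which isn't shown in the issue. It relies only on `nn.Module.register_forward_hook`, whose hook receives `(module, inputs, output)`.

```python
import torch
import torch.nn as nn


def record_conv_outputs(model: nn.Module):
    """Register a forward hook on every Conv2d that stores the full batched output."""
    records = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            records[name] = []

            # `name=name` binds the current layer name; `output` still has
            # its batch dimension, so every sample is kept.
            def hook(mod, inputs, output, name=name):
                records[name].append(output.detach())

            handles.append(module.register_forward_hook(hook))
    return records, handles


# Hypothetical stand-in model
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
)
records, handles = record_conv_outputs(model)

with torch.no_grad():
    model.eval()
    model(torch.randn(32, 3, 64, 64))  # one batch of 32 samples

for handle in handles:
    handle.remove()  # detach the hooks once recording is done

print({name: tuple(batches[0].shape) for name, batches in records.items()})
# {'0': (32, 8, 64, 64), '2': (32, 4, 64, 64)}
```

Each stored tensor keeps the leading batch dimension of 32, unlike the `[64, 64, 64]` shapes reported by the recorder above.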
I just tested it again by running four batches instead of one, and the subrecorder's number of samples increased to 4. This confirms my theory that only one sample per batch is recorded. It also seems that each subrecorder is a list of per-image activations, with one entry per batch:
```python
print('subrecorder len:', [len(subrecorder) for subrecorder in activations])
print('subrecorder {} samples: {}'.format(0, recorder.samples(0)))
```
Output:
```
subrecorder len: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
subrecorder 0 samples: 4
```
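I think I found the reason this is happening. In `recorders.py`, in the `_call` method of the class `_Hook`, only the first sample of `to_record` (the output activations) is appended to `self.data`. Since `to_record` is the batched output tensor, `to_record[0]` selects just the first sample, which also explains the missing batch dimension in the shapes above: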
```python
if self.index >= len(self.data):
    self.data.append(to_record[0])
    if reduction is None:
        self.data[-1] = [self.data[-1]]
else:
    if reduction is not None:
        self.data[self.index] = reduction(
            self.data[self.index], to_record[0]
        )
    else:
        self.data[self.index].append(to_record[0])
```
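To see why `samples(0)` equals the number of batches, here is a pure-Python sketch that copies the branch logic quoted above into a hypothetical `FirstSampleRecorder` class. It is not torchfunc's real `_Hook`: strings stand in for per-sample activation tensors, and `self.index` is fixed at 0 to model a single recorded layer.

```python
from typing import Callable, List, Optional


class FirstSampleRecorder:
    """Hypothetical stand-in reproducing the quoted `_Hook._call` branches."""

    def __init__(self) -> None:
        self.data: List = []
        self.index = 0  # a single recorded layer

    def _call(self, to_record, reduction: Optional[Callable] = None) -> None:
        if self.index >= len(self.data):
            self.data.append(to_record[0])  # only the first sample is kept
            if reduction is None:
                self.data[-1] = [self.data[-1]]
        else:
            if reduction is not None:
                self.data[self.index] = reduction(self.data[self.index], to_record[0])
            else:
                self.data[self.index].append(to_record[0])


recorder = FirstSampleRecorder()
for b in range(4):  # four batches of 32 samples each
    recorder._call([f"b{b}s{s}" for s in range(32)])

print(recorder.data[0])  # ['b0s0', 'b1s0', 'b2s0', 'b3s0']
```

Four batches of 32 samples produce four stored entries, one per batch, which matches the `subrecorder len` output above.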
Changing it to:
```python
if self.index >= len(self.data):
    self.data.append(to_record)
    if reduction is None:
        self.data[-1] = [self.data[-1]]
else:
    if reduction is not None:
        # pass the whole batch here too, consistent with the other branches
        self.data[self.index] = reduction(
            self.data[self.index], to_record
        )
    else:
        self.data[self.index].append(to_record)
```
fixed the problem.
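A matching pure-Python sketch of the patched logic, using a hypothetical `BatchRecorder` with strings in place of tensors, shows that each call now stores a whole batch. With real tensors, `torch.cat(recorder.data[i], dim=0)` would join the stored batches of layer `i` into a single tensor with one row per sample; the hypothetical `flatten_batches` helper plays that role for plain lists. Passing the whole batch in the `reduction` branch is an assumption that goes beyond the minimal patch.

```python
from typing import Callable, List, Optional


class BatchRecorder:
    """Hypothetical stand-in mirroring the patched `_Hook._call` branches."""

    def __init__(self) -> None:
        self.data: List = []
        self.index = 0  # a single recorded layer

    def _call(self, to_record, reduction: Optional[Callable] = None) -> None:
        if self.index >= len(self.data):
            self.data.append(to_record)  # keep the entire batch
            if reduction is None:
                self.data[-1] = [self.data[-1]]
        else:
            if reduction is not None:
                self.data[self.index] = reduction(self.data[self.index], to_record)
            else:
                self.data[self.index].append(to_record)


def flatten_batches(batches: List[list]) -> list:
    """List analogue of torch.cat(batches, dim=0): one entry per sample."""
    return [sample for batch in batches for sample in batch]


recorder = BatchRecorder()
for b in range(4):  # four batches of 32 samples each
    recorder._call([f"b{b}s{s}" for s in range(32)])

samples = flatten_batches(recorder.data[0])
print(len(recorder.data[0]), len(samples))  # 4 128
```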