torchfunc

Recorders only record activations of one sample, not the entire batch

Status: Open · joshuachough opened this issue 4 years ago · 2 comments

I pass an entire batch through the model, yet when I check the recorder's data, it only shows the activations for one sample (I think the last).

An abbreviated version of my code:

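```python
import torch
import torch.nn as nn
import torchfunc

# `model` and `testloader` are defined elsewhere in my code.
# Record the output of every Conv2d layer in the model.
recorder = torchfunc.hooks.recorders.ForwardOutput()
recorder.modules(model, types=(nn.Conv2d,))  # one-element tuple needs the trailing comma

with torch.no_grad():
    model.eval()

    for batch_idx, (data, target) in enumerate(testloader):
        print('data:', data.shape)
        output = model(data)
        break  # run a single batch only

activations = recorder.data
print('\nactivations:', len(activations))
print('subrecorders:', [subrecorder[0].shape for subrecorder in activations])
print('subrecorder {} samples: {}'.format(0, recorder.samples(0)))
```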

Output:

```
data: torch.Size([32, 3, 64, 64])
activations: 10
subrecorders: [torch.Size([64, 64, 64]), torch.Size([64, 64, 64]), torch.Size([128, 32, 32]), torch.Size([128, 32, 32]), torch.Size([256, 16, 16]), torch.Size([256, 16, 16]), torch.Size([256, 16, 16]), torch.Size([1024, 16, 16]), torch.Size([1024, 16, 16]), torch.Size([21, 16, 16])]
subrecorder 0 samples: 1
```

joshuachough, Feb 02 '21 22:02

I just tested it again by running four batches instead of one, and the subrecorder's number of samples increased to 4. This confirms my theory that only one sample per batch is being recorded. It also seems that each subrecorder is a list holding the activations of each recorded image.

```python
print('subrecorder len:', [len(subrecorder) for subrecorder in activations])
print('subrecorder {} samples: {}'.format(0, recorder.samples(0)))
```

Output:

```
subrecorder len: [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
subrecorder 0 samples: 4
```

joshuachough, Feb 02 '21 23:02

I think I found the reason this is happening. In `recorders.py`, in the `_call` method of the `_Hook` class, only the first sample of `to_record` (the output activations) is appended to `self.data`:

```python
# recorders.py, _Hook._call (as currently written)
if self.index >= len(self.data):
    self.data.append(to_record[0])  # to_record[0] is only the first sample of the batch
    if reduction is None:
        self.data[-1] = [self.data[-1]]
else:
    if reduction is not None:
        self.data[self.index] = reduction(
            self.data[self.index], to_record[0]
        )
    else:
        self.data[self.index].append(to_record[0])
```
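The effect of keeping `to_record[0]` can be reproduced without PyTorch. The sketch below is a hypothetical, simplified model of the quoted branch for `reduction is None`, with plain lists standing in for tensors; it assumes `to_record` is batch-first, so index 0 selects the first sample. The helper name `record_like_original` is made up for illustration. Four batches of 32 leave four entries, matching the count reported above, and each entry is the first sample of its batch.

```python
# Hypothetical, simplified model of the quoted _Hook._call branch (reduction=None).
# A "batch" is a plain list of per-sample values, standing in for the
# batch-first output tensor the hook receives as to_record (an assumption).
def record_like_original(data, index, to_record):
    if index >= len(data):
        data.append([to_record[0]])      # only the first sample is kept
    else:
        data[index].append(to_record[0])  # again, only the first sample

data = []
for b in range(4):                                       # four batches...
    batch = [f"batch{b}-sample{s}" for s in range(32)]   # ...of 32 samples each
    record_like_original(data, 0, batch)

print(len(data[0]))  # 4
print(data[0])       # ['batch0-sample0', 'batch1-sample0', 'batch2-sample0', 'batch3-sample0']
```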

Changing it to

```python
if self.index >= len(self.data):
    # keep the whole batch, not just its first sample
    self.data.append(to_record)
    if reduction is None:
        self.data[-1] = [self.data[-1]]
else:
    if reduction is not None:
        # reduce whole batches as well, so both arguments are batch-shaped
        self.data[self.index] = reduction(
            self.data[self.index], to_record
        )
    else:
        self.data[self.index].append(to_record)
```

fixed the problem.
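If patching the installed library is not an option, full-batch recording can also be sketched with a plain PyTorch forward hook via `nn.Module.register_forward_hook`. This is not torchfunc's API: the helper `record_full_batches`, the toy model, and the random inputs are hypothetical stand-ins for `model` and `testloader`. Each hook stores the whole batch output, so concatenating along dimension 0 yields one tensor per layer holding every sample.

```python
import torch
import torch.nn as nn

# Hypothetical helper (plain PyTorch, not torchfunc): record the full batch
# output of every module whose type is in `types`, on every forward pass.
def record_full_batches(model, types=(nn.Conv2d,)):
    records = {}   # module name -> list of batch-shaped output tensors
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, types):
            records[name] = []

            def hook(mod, inputs, output, name=name):
                # output is [batch, C, H, W]; keep all of it, detached from autograd
                records[name].append(output.detach())

            handles.append(module.register_forward_hook(hook))
    return records, handles

# Toy model and random data standing in for `model` and `testloader`.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 16, 3, padding=1)
)
records, handles = record_full_batches(model)

model.eval()
with torch.no_grad():
    for _ in range(4):                        # four batches of 32 images
        model(torch.randn(32, 3, 16, 16))

for h in handles:
    h.remove()                                # unregister the hooks when done

# Concatenate along the batch dimension: one tensor with all 128 samples.
all_first = torch.cat(records["0"], dim=0)
print(all_first.shape)  # torch.Size([128, 8, 16, 16])
```

Keeping every batch holds all activations in memory, so for a large test set it may be worth recording only a few batches.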

joshuachough, Feb 02 '21 23:02