zennit-crp
run_distributed method does not consider batch size
Hi @rachtibat,
The `run_distributed` method of the `FeatureVisualization` class does not take the actual `batch_size` into account in the multi-target case.
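The following minimal sketch shows why the multi-target case matters. It assumes, hypothetically, that `run_distributed` repeats each sample once per target and builds one condition per (sample, target) pair. `broadcast_conditions` and the `{"y": target}` condition layout are made up for this illustration and are not the zennit-crp API. Under that assumption, a single attribution call receives `n_samples * n_targets` rows, which can exceed `batch_size` even when `n_samples` alone does not.

```python
from itertools import product


def broadcast_conditions(sample_indices, targets):
    """Hypothetical helper: pair every sample with every target.

    Returns the repeated sample indices and one condition dict per pair,
    mimicking how a multi-target run could enlarge the effective batch.
    """
    pairs = list(product(sample_indices, targets))
    repeated_samples = [s for s, _ in pairs]
    conditions = [{"y": t} for _, t in pairs]
    return repeated_samples, conditions


samples, conds = broadcast_conditions([0, 1, 2, 3], targets=[5, 7, 9])
batch_size = 8
# 4 samples fit into one batch, but 4 samples x 3 targets = 12 rows do not.
print(len(conds), len(conds) > batch_size)  # 12 True
```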
Maybe include something like:
```python
import math

# Split the broadcast inputs into chunks of at most batch_size items.
# math.ceil already yields a single batch when len(conditions) <= batch_size,
# and it must be based on len(conditions), not n_samples, in the multi-target case.
n_batches = math.ceil(len(conditions) / batch_size)

for b in range(n_batches):
    batch = slice(b * batch_size, (b + 1) * batch_size)

    data_batch = data_broadcast[batch]
    conditions_batch = conditions[batch]

    # dict_inputs is linked to the FeatHooks
    dict_inputs["sample_indices"] = sample_indices[batch]
    dict_inputs["targets"] = targets[batch]

    # composites are already registered before this loop
    self.attribution(data_batch, conditions_batch, None, exclude_parallel=False)
```
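The chunking logic can be checked in isolation with the minimal, framework-free sketch below. `iter_batches` is a hypothetical helper, not part of zennit-crp. It shows that `math.ceil(n_items / batch_size)` already produces exactly one batch when everything fits, so a separate `if` branch is unnecessary, and that the last batch may be shorter than `batch_size`.

```python
import math


def iter_batches(n_items, batch_size):
    """Hypothetical helper: yield slice objects covering range(n_items)
    in consecutive chunks of at most batch_size items."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    n_batches = math.ceil(n_items / batch_size)
    for b in range(n_batches):
        # Slicing past the end of a list or tensor is safe,
        # so the last chunk may simply be shorter.
        yield slice(b * batch_size, (b + 1) * batch_size)


conditions = [{"y": t} for t in range(12)]
sizes = [len(conditions[s]) for s in iter_batches(len(conditions), batch_size=5)]
print(sizes)  # [5, 5, 2]
```

Because each slice object is reused for `data_broadcast`, `conditions`, `sample_indices` and `targets`, the four sequences stay aligned batch by batch.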
This would fix a GPU memory issue I am running into.
Best, Max