zennit-crp

run_distributed method does not consider batch size

Open maxdreyer opened this issue 1 year ago • 0 comments

Hi @rachtibat,

The run_distributed method of the FeatureVisualization class does not take the actual batch_size into account in the multi-target case, so self.attribution can receive more than batch_size broadcast inputs in a single call.

Maybe something like the following could be included:

import math  # at module level

# Fragment from inside FeatureVisualization.run_distributed: self, data_broadcast,
# conditions, sample_indices, targets, dict_inputs and batch_size come from the method.
# Chunk over len(conditions) rather than n_samples: in the multi-target case there can
# be more conditions than samples, so an n_samples check may skip batching entirely.
n_batches = math.ceil(len(conditions) / batch_size)

for b in range(n_batches):
    start, end = b * batch_size, (b + 1) * batch_size
    data_broadcast_ = data_broadcast[start:end]
    conditions_ = conditions[start:end]
    # dict_inputs is linked to FeatHooks, so update it for the current chunk
    dict_inputs["sample_indices"] = sample_indices[start:end]
    dict_inputs["targets"] = targets[start:end]

    # composites are already registered before
    self.attribution(data_broadcast_, conditions_, None, exclude_parallel=False)

This would fix a GPU memory issue I am running into.
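To illustrate the chunking logic independently of the library, here is a minimal, self-contained sketch in plain Python. `run_in_batches` and the `attribute` callback are hypothetical stand-ins for `run_distributed` and `self.attribution`; plain lists replace tensors, and the condition dicts are placeholders rather than the real zennit-crp condition format. The sketch slices all broadcast sequences consistently and records the size of each chunk.

```python
import math
from typing import Callable, List, Sequence


def run_in_batches(
    data_broadcast: Sequence,
    conditions: Sequence,
    sample_indices: Sequence,
    targets: Sequence,
    batch_size: int,
    attribute: Callable[[Sequence, Sequence, dict], None],
) -> List[int]:
    """Call `attribute` on consecutive chunks of at most `batch_size` items.

    Hypothetical helper for illustration; returns the chunk sizes it used.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    if not (len(data_broadcast) == len(conditions) == len(sample_indices) == len(targets)):
        raise ValueError("all broadcast sequences must have the same length")

    chunk_sizes = []
    # Chunk over the number of broadcast conditions, not the number of unique samples
    n_batches = math.ceil(len(conditions) / batch_size)
    for b in range(n_batches):
        start, end = b * batch_size, (b + 1) * batch_size
        # Hypothetical stand-in for the dict_inputs shared with the hooks
        dict_inputs = {
            "sample_indices": sample_indices[start:end],
            "targets": targets[start:end],
        }
        attribute(data_broadcast[start:end], conditions[start:end], dict_inputs)
        chunk_sizes.append(len(conditions[start:end]))
    return chunk_sizes


# Demo: 3 samples, each broadcast over 4 targets -> 12 (sample, target) pairs
samples = [0, 1, 2]
target_ids = [10, 11, 12, 13]
pairs = [(s, t) for s in samples for t in target_ids]
data = [f"x{s}" for s, _ in pairs]
conds = [{"target": t} for _, t in pairs]  # placeholder condition dicts
idx = [s for s, _ in pairs]
tgts = [t for _, t in pairs]

seen = []
sizes = run_in_batches(
    data, conds, idx, tgts, batch_size=5,
    attribute=lambda d, c, inp: seen.append(len(d)),
)
print(sizes)  # [5, 5, 2]
```

Here n_samples is 3, which is not larger than batch_size=5, so a check on n_samples would select a single batch of all 12 inputs; chunking on len(conditions) keeps every call at or below batch_size.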

Best, Max

maxdreyer, Apr 27 '23 13:04