`BatchSplittingSampler` returns wrong length

Open dwahdany opened this issue 11 months ago • 2 comments

🐛 Bug

`BatchSplittingSampler` currently computes its length as:

expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
return int(len(self.sampler) * (expected_batch_size / self.max_batch_size))

Converting the result directly with `int()` truncates it, so whenever the product has a fractional part the reported number of batches is one too low. Instead, we need to take the ceiling first:

expected_batch_size = self.sampler.sample_rate * self.sampler.num_samples
return int(np.ceil(len(self.sampler) * (expected_batch_size / self.max_batch_size)))

Some libraries, such as PyTorch Lightning, will skip the last batch if the length is reported incorrectly, resulting in no actual step occurring at all.
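
To make the off-by-one concrete, here is a minimal standalone sketch of the two formulas. It does not import Opacus; the function names and the numeric values are hypothetical and chosen only so that the product has a fractional part. `num_logical_batches` stands in for `len(self.sampler)`.

```python
import math


def reported_len_truncated(num_logical_batches, sample_rate, num_samples, max_batch_size):
    # Mirrors the current formula: int() drops the fractional part.
    expected_batch_size = sample_rate * num_samples
    return int(num_logical_batches * (expected_batch_size / max_batch_size))


def reported_len_ceiled(num_logical_batches, sample_rate, num_samples, max_batch_size):
    # Mirrors the proposed formula: round up before converting to int.
    expected_batch_size = sample_rate * num_samples
    return int(math.ceil(num_logical_batches * (expected_batch_size / max_batch_size)))


# Hypothetical setup: 4 logical batches, expected logical batch size
# 0.25 * 400 = 100, physical batches capped at 64 samples.
# 4 * (100 / 64) = 6.25
print(reported_len_truncated(4, 0.25, 400, 64))  # 6
print(reported_len_ceiled(4, 0.25, 400, 64))     # 7
```

When the product is already a whole number (for example, an expected batch size of 64 with `max_batch_size=64`), both versions return the same value, so the change only affects the fractional case.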

dwahdany avatar Mar 22 '24 18:03 dwahdany