state_dict on the dataset seems to be called more often than expected
🐛 Describe the bug
Consider the following code:
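```python
from unittest import TestCase

import torch
from torchdata.stateful_dataloader import Stateful, StatefulDataLoader


class DatasetStateIterable(torch.utils.data.IterableDataset, Stateful):
    def __init__(self, length):
        self.length = length

    def __iter__(self):
        return iter(list(range(self.length)))

    def state_dict(self):
        # Log every call so the number of snapshots can be counted.
        print("Calling state dict")
        return {"key": "value"}

    def load_state_dict(self, state_dict):
        pass


class TestSimple(TestCase):
    def test(self):
        dataset = DatasetStateIterable(100)
        dl = StatefulDataLoader(
            dataset=dataset,
            num_workers=1,
            snapshot_every_n_steps=10,
        )
        it = iter(dl)
        # Consume 30 batches, i.e. three snapshot intervals of 10 steps.
        for _ in range(30):
            next(it)
        # Fails on purpose.
        self.assertTrue(False)
```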
Here the snapshot frequency is set to every 10 steps and the loop runs for 30 steps, so I expected `state_dict` to be called roughly once per snapshot, i.e. about 3 times. Instead, the output shows it being called on the dataset 12 times:
```
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
```
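As a rough comparison, here is a small sketch of the count I expected. The helper `expected_state_dict_calls` is hypothetical, and it assumes one dataset `state_dict()` call per snapshot, with no extra initial snapshot:

```python
def expected_state_dict_calls(num_steps: int, snapshot_every_n_steps: int) -> int:
    # Assumption: exactly one dataset.state_dict() call per snapshot, and a
    # snapshot after every `snapshot_every_n_steps` batches (no initial one).
    return num_steps // snapshot_every_n_steps


expected = expected_state_dict_calls(30, 10)
observed = 12  # number of "Calling state dict" lines in the output above
print(expected, observed, observed / expected)  # prints: 3 12 4.0
```

Under that assumption the dataset's `state_dict` is called about 4 times as often as the snapshot schedule suggests.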
Versions
Latest git commit - 82918dd