
State_dict on dataset seems to be called more often than expected

gokulavasan opened this issue 8 months ago

🐛 Describe the bug

Consider the following code:

import torch
from unittest import TestCase

from torchdata.stateful_dataloader import StatefulDataLoader
# Import path for the Stateful protocol may vary by torchdata version.
from torchdata.stateful_dataloader.stateful import Stateful


class DatasetStateIterable(torch.utils.data.IterableDataset, Stateful):
    def __init__(self, length):
        self.length = length

    def __iter__(self):
        return iter(list(range(self.length)))

    def state_dict(self):
        print("Calling state dict")
        return {"key": "value"}

    def load_state_dict(self, state_dict):
        pass


class TestSimple(TestCase):
    def test(self):
        dataset = DatasetStateIterable(100)
        dl = StatefulDataLoader(
            dataset=dataset,
            num_workers=1,
            snapshot_every_n_steps=10,
        )
        it = iter(dl)
        for _ in range(30):
            next(it)
        # Fail intentionally so the captured stdout is shown in the test output.
        self.assertTrue(False)

Here the snapshot frequency is set to every 10 steps and the iteration is carried out for 30 steps, yet the output below shows that state_dict is called on the dataset 12 times:

Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
Calling state dict
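
Not part of the original report, but here is a minimal sketch of how the repro could be instrumented to see where the extra calls come from. It assumes the Stateful protocol is matched structurally (so no explicit base class is needed) and that torchdata is installed; CountingIterable and its counter are hypothetical additions. Each call is tagged with a per-process count and the process id, since with num_workers=1 the dataset is copied into a worker process and calls may originate from either process.

import os

import torch
from torchdata.stateful_dataloader import StatefulDataLoader


class CountingIterable(torch.utils.data.IterableDataset):
    # Hypothetical variant of the dataset above that counts state_dict calls.
    def __init__(self, length):
        self.length = length
        self.calls = 0  # per-process count: the worker keeps its own copy

    def __iter__(self):
        return iter(range(self.length))

    def state_dict(self):
        self.calls += 1
        # Tag each call with the pid to see whether it comes from the
        # main process or the worker.
        print(f"state_dict call #{self.calls} in pid {os.getpid()}")
        return {"key": "value"}

    def load_state_dict(self, state_dict):
        pass


if __name__ == "__main__":
    dl = StatefulDataLoader(
        dataset=CountingIterable(100),
        num_workers=1,
        snapshot_every_n_steps=10,
    )
    it = iter(dl)
    for _ in range(30):
        next(it)
    # With snapshots every 10 steps, roughly 30 / 10 = 3 calls would be
    # expected over 30 steps, not the 12 observed above.

Running this against the commit below should make it easier to tell whether the extra calls happen in the main process, in the worker, or in both.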

Versions

Latest git commit - 82918dd

gokulavasan · Jun 10 '24 22:06