
Are there any plans to optimize the fetcher_state in StatefulDataLoader?

Open howitry opened this issue 1 month ago • 4 comments

Since _IterableDatasetFetcher has no state attribute (https://github.com/pytorch/pytorch/blob/v2.6.0/torch/utils/data/_utils/fetch.py#L19), the dataset_iter_state entry in the current fetcher_state is None (https://github.com/meta-pytorch/data/blob/v0.11.0/torchdata/stateful_dataloader/worker.py#L277). Could this cause prefetched data to be discarded during resume?
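For reference, a quick way to inspect what actually ends up in the checkpoint is just to print the loader's state dict. The nested key names are torchdata implementation details and may differ between versions, and the toy dataset below is only for illustration (it skips per-worker sharding):

import torch
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader

class PlainIterable(IterableDataset):  # toy dataset with no state_dict/load_state_dict
    def __iter__(self):
        return iter(range(8))  # sharding across workers omitted for brevity

dl = StatefulDataLoader(PlainIterable(), batch_size=1, num_workers=2, prefetch_factor=2)
it = iter(dl)
next(it)  # pull one batch so the workers have already prefetched a few more
print(dl.state_dict())  # the nested structure shows what is (and isn't) captured per worker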

howitry · Dec 12 '25 09:12

Hi @howitry Yes, the prefetched data itself is discarded on resume: we only store the number of yielded samples, so the prefetched batches are not saved in the checkpoint.

However, those batches are not lost. We "discard" them only in the sense that we do not store them in the checkpoint. Because we keep track of the number of batches yielded, the batches that were prefetched by the workers but not yet yielded by the dataloader are simply fetched again and yielded after the resume. For non-stateful iterable datasets, the StatefulDataLoader fast-forwards, i.e. skips the number of already-yielded samples.
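A minimal sketch of that fast-forward idea (not the actual torchdata implementation, just the concept): recreate the iterator and skip the first n samples that were already yielded before the checkpoint.

import itertools

def naive_fast_forward(dataset_iterable, num_yielded):
    """Recreate the iterator and drop the samples that were already yielded."""
    it = iter(dataset_iterable)
    for _ in itertools.islice(it, num_yielded):  # consume and discard num_yielded samples
        pass
    return it

# toy usage: pretend 5 samples were yielded before the checkpoint
print(list(naive_fast_forward(range(10), 5)))  # [5, 6, 7, 8, 9]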

ramanishsingh · Dec 12 '25 18:12

> Because we keep track of the number of batches yielded, the batches that were prefetched by the workers but not yet yielded by the dataloader are simply fetched again and yielded after the resume.

According to your statement, although the StatefulDataLoader tracks the number of batches yielded, those samples have already been yielded by the dataset, so the StatefulDataLoader cannot actually obtain the previously prefetched data when it prefetches again. Would it be possible to save the prefetched data in the StatefulDataLoader checkpoint, to ensure that training resumes on exactly the right data? I think this is important, especially for large-scale pre-training on hundreds of GPUs with a prefetch_factor > 2, where the amount of prefetched data is actually quite large.

howitry · Dec 13 '25 02:12

@ramanishsingh can clarify how we don't "lose data" and still resume on the correct data.

But regarding saving the actual data in the checkpoint: that is not feasible, because the data (even just the prefetched portion) can be very large, and saving it into the checkpoint can significantly slow down the checkpointing services used in training frameworks. Also, at most we would have prefetch_factor * num_workers prefetched samples in flight; in most setups that is on the order of tens of samples.
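As a rough, purely illustrative estimate of how much data that is (all of the numbers below are assumptions, not measurements):

# back-of-the-envelope estimate of the in-flight prefetched data
num_workers = 8          # assumed
prefetch_factor = 4      # assumed
batch_size = 16          # assumed
bytes_per_sample = 4096  # assumed, e.g. a few KB of tokenized text

in_flight_batches = num_workers * prefetch_factor                  # 32 batches
in_flight_bytes = in_flight_batches * batch_size * bytes_per_sample
print(in_flight_batches, in_flight_bytes / 2**20, "MiB")           # 32 batches, 2.0 MiB

Even with far larger samples or many more workers, this typically stays small relative to the model and optimizer state in the same checkpoint, which is the trade-off described above.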

divyanshk · Dec 16 '25 00:12

> @ramanishsingh can clarify how we don't "lose data" and still resume on the correct data.
>
> But regarding saving the actual data in the checkpoint: that is not feasible, because the data (even just the prefetched portion) can be very large, and saving it into the checkpoint can significantly slow down the checkpointing services used in training frameworks. Also, at most we would have prefetch_factor * num_workers prefetched samples in flight; in most setups that is on the order of tens of samples.

Please correct me if I am misunderstanding: because data is prefetched, the iterable dataset's state has already advanced to the prefetched position. When training resumes, the prefetched-but-unused data is therefore effectively discarded. In other words, even if the StatefulDataLoader records the actual number of yielded batches, the dataset's state is already at the previously prefetched position, so when the StatefulDataLoader prefetches again, the dataset can only yield from the position after the previously prefetched data. @divyanshk @ramanishsingh

howitry · Dec 16 '25 03:12

Hi @howitry

Let me try to work through this with the help of a couple of examples.
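One aside first (a generic DataLoader property, not specific to torchdata): with num_workers > 0 each worker process operates on its own copy of the dataset object, so prefetching in the workers never advances any state held by the dataset instance in the main process. A small sketch using a hypothetical position-tracking dataset:

import torch
from torch.utils.data import IterableDataset, get_worker_info
from torchdata.stateful_dataloader import StatefulDataLoader

class PositionTrackingDS(IterableDataset):  # hypothetical toy dataset
    def __init__(self, size=10):
        self.size = size
        self.pos = 0  # advanced only in whichever process actually iterates

    def __iter__(self):
        info = get_worker_info()
        for i in range(self.size):
            if info is None or i % info.num_workers == info.id:
                self.pos = i + 1
                yield i

ds = PositionTrackingDS(10)
dl = StatefulDataLoader(ds, batch_size=1, num_workers=2, prefetch_factor=2)
for _ in dl:
    pass
# The workers iterated (and prefetched) on their own copies of `ds`,
# so the instance in the main process was never advanced.
print(ds.pos)  # 0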

  1. In the case of a non-stateful iterable dataset
import torch
from torch.utils.data import IterableDataset, get_worker_info
from torchdata.stateful_dataloader import StatefulDataLoader

print("%"*10, "\n", "Without break", "\n", "%"*10)


class MyIterableDS(IterableDataset):
    
    def __init__(self, size=100):
        self.size = size
        
    def __iter__(self):  # iterate over samples
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
        
        for i, s in enumerate(range(self.size)):
            if i % num_workers == worker_id:
                yield s
    
    def __len__(self):
        return self.size



dataset = MyIterableDS(10)
dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, prefetch_factor=2)


for batch in dataloader:
    print(batch)


print("%"*10, "\n", "With break", "\n", "%"*10)


dataset = MyIterableDS(10)
dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, prefetch_factor=2)
it = iter(dataloader)

break_point = 5
sd = None
batch_num = 0
while True:
    batch = next(it)
    print(batch)
    batch_num += 1
    sd = dataloader.state_dict()  # checkpoint after every yielded batch
    if batch_num == break_point:
        break
print("break happened")
# build a fresh dataloader and resume from the saved state
dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, prefetch_factor=2)
dataloader.load_state_dict(sd)
for batch in dataloader:
    print(batch)

Output

%%%%%%%%%% 
 Without break 
 %%%%%%%%%%
tensor([0])
tensor([1])
tensor([2])
tensor([3])
tensor([4])
tensor([5])
tensor([6])
tensor([7])
tensor([8])
tensor([9])
%%%%%%%%%% 
 With break 
 %%%%%%%%%%
tensor([0])
tensor([1])
tensor([2])
tensor([3])
tensor([4])
break happened
W1216 211744.504 stateful_dataloader.py:1078] Neither dataset nor iter(dataset) defines state_dict/load_state_dict so we are naively fast-forwarding your dataset by 5 steps. For more efficient resumes, please implement `state_dict` and `load_state_dict` in your IterableDataset and/or iterator.
tensor([5])
tensor([6])
tensor([7])
tensor([8])
tensor([9])

Note the warning about the fast-forward: because the dataset defines no state_dict, the loader naively skips the 5 already-yielded samples and resumes at tensor([5]), so no samples are lost.

  2. In the case of a stateful iterable dataset
import torch
from torch.utils.data import IterableDataset, get_worker_info
from torchdata.stateful_dataloader import StatefulDataLoader

print("%"*10, "\n", "Without break", "\n", "%"*10)


class MyStatefulIterableDS(IterableDataset):
    def __init__(self, size=100):
        self.size = size
        # Track position for each worker
        self.worker_states = {}
    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers
        worker_id = worker_info.id
        # Get the starting position for this worker
        start_pos = self.worker_states.get(worker_id, 0)
        # Only yield items assigned to this worker, starting from start_pos
        for i, s in enumerate(range(self.size)):
            if i % num_workers == worker_id and i >= start_pos:
                yield s
                # Update the worker's state after yielding
                self.worker_states[worker_id] = i + 1
    def __len__(self):
        return self.size
    def state_dict(self):
        # Save the current position for each worker
        return {'worker_states': self.worker_states.copy()}
    def load_state_dict(self, state):
        self.worker_states = state.get('worker_states', {}).copy()



dataset = MyStatefulIterableDS(10)
dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, prefetch_factor=2)


for batch in dataloader:
    print(batch)


print("%"*10, "\n", "With break", "\n", "%"*10)


dataset = MyStatefulIterableDS(10)
dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, prefetch_factor=2)
it = iter(dataloader)

break_point = 5
sd = None
batch_num = 0
while True:
    batch = next(it)
    print(batch)
    batch_num +=1
    sd = dataloader.state_dict()
    if batch_num == break_point:
        break
print("break happened")
dataloader = StatefulDataLoader(dataset, batch_size=1, num_workers=2, prefetch_factor=2)
dataloader.load_state_dict(sd)
for batch in dataloader:
    print(batch)

Output

%%%%%%%%%% 
 Without break 
 %%%%%%%%%%
tensor([0])
tensor([1])
tensor([2])
tensor([3])
tensor([4])
tensor([5])
tensor([6])
tensor([7])
tensor([8])
tensor([9])
%%%%%%%%%% 
 With break 
 %%%%%%%%%%
tensor([0])
tensor([1])
tensor([2])
tensor([3])
tensor([4])
break happened
tensor([3])
tensor([4])
tensor([5])
tensor([6])
tensor([7])
tensor([8])
tensor([9])

So, in both cases we don't lose any samples.
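As a quick sanity check (a toy helper, not part of torchdata), you can verify from the printed outputs that every sample of the uninterrupted run also appears in the pre-break plus post-resume batches. This checks coverage only; it says nothing about ordering or repeats:

def covers_all_samples(full_run, pre_break, post_resume):
    """True if every sample seen in the uninterrupted run is also seen across the two halves."""
    return set(full_run) <= set(pre_break) | set(post_resume)

# values taken from the stateful example's output above
full_run = list(range(10))            # run without a break
pre_break = [0, 1, 2, 3, 4]           # yielded before the checkpoint
post_resume = [3, 4, 5, 6, 7, 8, 9]   # yielded after load_state_dict
print(covers_all_samples(full_run, pre_break, post_resume))  # True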

ramanishsingh · Dec 17 '25 05:12