
Add an example using a library such as Hugging Face datasets

conceptofmind opened this issue 5 months ago · 2 comments

🚀 The feature

I think it would make sense to provide a "real-world" example for using the StatefulDataLoader with a popular library such as Hugging Face datasets.

For example, the code below uses IterableDataset, StatefulDataLoader, and Hugging Face streaming datasets together:

import os
from typing import Optional

import torch
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torch.distributed import destroy_process_group
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data import IterableDataset
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
)

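# torchrun provides RANK / WORLD_SIZE / LOCAL_RANK; pin each process to its
# local GPU and initialize the NCCL process group.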
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert torch.cuda.is_available()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()


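# Wraps a Hugging Face streaming dataset so that the StatefulDataLoader can
# checkpoint and restore its position via state_dict()/load_state_dict().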
class TokenizedDataset(IterableDataset, Stateful):
    def __init__(
        self,
        path: str,
        tokenizer: AutoTokenizer,
        name: Optional[str] = None,
        split: str = "train",
        streaming: bool = True,
        max_length: int = 2048,
        ddp_rank: int = 0,
        ddp_world_size: int = 1,
    ):
        dataset = load_dataset(path, name, split=split, streaming=streaming)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.train_dataset = split_dataset_by_node(
            dataset=dataset, rank=ddp_rank, world_size=ddp_world_size
        )

    def __iter__(self):
        for sample in iter(self.train_dataset):
            tokenized = self.tokenizer(
                sample["text"],
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_special_tokens_mask=True,
            )
            yield tokenized

    def load_state_dict(self, state_dict):
        assert "data" in state_dict
        self.train_dataset.load_state_dict(state_dict["data"])

    def state_dict(self):
        return {"data": self.train_dataset.state_dict()}


tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm_probability=0.5
)

tokenized_dataset = TokenizedDataset(
    path="Salesforce/wikitext",
    name="wikitext-2-v1",
    tokenizer=tokenizer,
    max_length=2048,
    ddp_rank=rank,
    ddp_world_size=world_size,
)

trainloader = StatefulDataLoader(
    dataset=tokenized_dataset,
    batch_size=64,
    num_workers=1,
    collate_fn=data_collator,
)

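# Run a few steps, then capture the dataloader state at step 2.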
for step, batch in enumerate(trainloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    print(step)
    print(batch)
    if step == 2:
        dataloader_state_dict = trainloader.state_dict()
        print(dataloader_state_dict)
        break

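# Restore the captured state; iteration should resume after the last batch
# seen before the state was captured, not from the start of the stream.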
print(f"restart from checkpoint")
trainloader.load_state_dict(dataloader_state_dict)
for step, batch in enumerate(trainloader):
    batch = {k: v.to(device) for k, v in batch.items()}
    print(step)
    print(batch)
    if step == 2:
        dataloader_state_dict = trainloader.state_dict()
        print(dataloader_state_dict)
        break

destroy_process_group()
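
In an actual job launched with torchrun, the captured state dict would typically be written to disk alongside the model/optimizer checkpoint and loaded in a fresh process before resuming. A minimal sketch of that, continuing from the script above and assuming one state file per rank (the file naming here is just a placeholder, not an existing convention):

# Hypothetical persistence of the dataloader state across runs (sketch only;
# the per-rank file name is a placeholder choice).
ckpt_path = f"dataloader_state_rank{rank}.pt"

# At checkpoint time, save the dataloader state along with the rest of the
# training state:
torch.save(trainloader.state_dict(), ckpt_path)

# In a later run, after rebuilding tokenizer/dataset/trainloader as above,
# restore the position before the first iteration. weights_only=False because
# the state dict may contain arbitrary Python objects, not just tensors.
if os.path.exists(ckpt_path):
    trainloader.load_state_dict(torch.load(ckpt_path, weights_only=False))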

Motivation, pitch

If something like the above is both correct and useful, I would be happy to provide it as an example in the repository.

Alternatives

If not, this can simply be left as a closed issue for others to reference in the future.

Additional context

I am willing to add more to this example as well if needed.

Thank you,

Enrico

conceptofmind · Aug 08 '25 06:08

Hey, this is a very valid request. Thank YOU for raising it. We can think of simple scripts (like above) using COCO or Alpaca datasets. I know TorchTune would have lots of cool abstractions on HuggingFace and StatefulDataLoader, but maybe we can provide simpler ones as examples. cc @ramanishsingh

divyanshk · Aug 08 '25 17:08

Hello,

I would be happy to contribute a few different examples.

The above was just the minimal amount of code for training a BERT-like model with MLM, DDP, and Hugging Face streaming/iterable datasets using StatefulDataLoader and torchrun.

I can also cover map-style datasets, or ViT-like and GPT-like training, with StatefulDataLoader (a rough map-style sketch is below).
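
For the map-style case, a minimal sketch could look something like this (the toy dataset below is just a placeholder for illustration, not tied to any particular HF dataset):

import torch
from torch.utils.data import Dataset
from torchdata.stateful_dataloader import StatefulDataLoader


# Toy placeholder map-style dataset, for illustration only.
class ToyMapDataset(Dataset):
    def __init__(self, n: int = 1000):
        self.data = torch.arange(n)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


loader = StatefulDataLoader(
    ToyMapDataset(), batch_size=8, shuffle=True, num_workers=1
)

for step, batch in enumerate(loader):
    if step == 2:
        state = loader.state_dict()  # capture mid-epoch position, incl. sampler state
        break

# A fresh iteration should continue from step 3 instead of restarting the epoch.
loader.load_state_dict(state)
for step, batch in enumerate(loader):
    print(step, batch)
    break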

IMO, having the most minimal example possible is a good thing, as it lets users hack on what they need from a simple baseline reference. This was the first thing I looked for when checking out the docs and this repository.

Thank you,

Enrico

conceptofmind · Aug 08 '25 18:08