[Orca] Code Design for Distributed PyTorch DataLoader on Ray
Description
This issue describes a new distributed data input that can be used in the same way as a standard PyTorch DataLoader.
Motivation
- The pure Ray pipeline needs a distributed data input so that users can switch seamlessly from a typical PyTorch program.
- Ray Datasets is not user-friendly enough; users would need to learn a new API in order to use it.
API Usage and Code Design
Here is a typical usage:
from torch.utils.data import DataLoader

# create a typical PyTorch DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=32)

estimator = Estimator.from_torch(model, optimizer, loss, metrics, config, backend="ray")
train_stats = estimator.fit(train_dataloader, epochs=2, batch_size=32)
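For completeness, train_dataset above can be any map-style PyTorch dataset. A minimal toy example is sketched below; the TensorDataset, tensor shapes, and sizes are purely illustrative assumptions, not part of the proposed API.

import torch
from torch.utils.data import TensorDataset

# Toy dataset purely for illustration; shapes and sizes are arbitrary.
x = torch.randn(1024, 10)
y = torch.randint(0, 2, (1024,))
train_dataset = TensorDataset(x, y)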
The code design is roughly as follows:
if isinstance(data, torch.utils.data.DataLoader):
    def assign_batch_worker():
        # Assign batches to each worker here:
        # 1. Build a dict {worker_rank: batch indices},
        #    e.g. {0: {0, 1, 6}, 1: {2, 3, 7}, 2: {4, 5}};
        # 2. Turn it into a list of batch index lists ordered by worker rank,
        #    e.g. [[0, 1, 6], [2, 3, 7], [4, 5]];
        # 3. Build a list containing the actual batches following the index lists from step 2,
        #    e.g. [[batch0, batch1, batch6], [batch2, batch3, batch7], [batch4, batch5]].
        return shards

    shards = assign_batch_worker()

    def make_data_creator(shard):
        def data_creator():
            return shard
        return data_creator

    remote_worker_stats = []
    for worker, shard in zip(self.remote_workers, shards):
        params["data_creator"] = make_data_creator(shard)
        stats = worker.train_epochs.remote(**params)
        remote_worker_stats.append(stats)
    worker_stats = ray.get(remote_worker_stats)
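As a point of reference, a minimal sketch of assign_batch_worker could assign batches round-robin, as below. The dataloader and num_workers parameters and the round-robin policy are illustrative assumptions rather than part of the design above; it assumes the whole DataLoader is iterated once on the driver and the materialized batches are grouped per worker rank.

# Illustrative sketch only: round-robin assignment of materialized batches to workers.
# Assumes a single pass over the DataLoader on the driver is acceptable.
def assign_batch_worker(dataloader, num_workers):
    # One list of batches per worker rank.
    shards = [[] for _ in range(num_workers)]
    for batch_idx, batch in enumerate(dataloader):
        worker_rank = batch_idx % num_workers
        shards[worker_rank].append(batch)
    # e.g. with 8 batches and 3 workers:
    # [[batch0, batch3, batch6], [batch1, batch4, batch7], [batch2, batch5]]
    return shards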
@yushan111
Does assign_batch_worker read all the data from the original PyTorch DataLoader and then return a list containing all of these batches? @hkvision