
[Orca] Code Design for Distributed PyTorch DataLoader on Ray

Open sgwhat opened this issue 1 year ago • 2 comments

Description

This issue describes a new distributed data input that can be used in the same way as the standard PyTorch DataLoader.

Motivation

  • A pure Ray pipeline needs a distributed data input that lets users switch seamlessly from a typical PyTorch program.
  • Ray Datasets is not user-friendly enough; users have to learn a new API before they can use it.

API Usage and Code Design

A typical usage looks like this:

# create a typical PyTorch DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=32) 

estimator = Estimator.from_torch(model, optimizer, loss, metrics, config, backend="ray")

train_stats = estimator.fit(train_dataloader, epochs=2, batch_size=32)

The code design is sketched as follows:

if isinstance(data, torch.utils.data.DataLoader):

    def assign_batch_worker():
        # Assign batches to each worker:
        # 1. Build a dict {worker_rank: batch indices}, e.g. {0: {0, 1, 6}, 1: {2, 3, 7}, 2: {4, 5}}.
        # 2. Convert it to a list of index lists ordered by worker rank: [[0, 1, 6], [2, 3, 7], [4, 5]].
        # 3. Replace each index with its batch, keeping the order from step 2,
        #    e.g. [[batch0, batch1, batch6], [batch2, batch3, batch7], [batch4, batch5]].
        return shards
    
    shards = assign_batch_worker()

    def make_data_creator(shard):
        def data_creator():
            return shard
        return data_creator

    remote_worker_stats = []
    for worker, shard in zip(self.remote_workers, shards):
        params["data_creator"] = make_data_creator(shard)
        stats = worker.train_epochs.remote(**params)
        remote_worker_stats.append(stats)

    worker_stats = ray.get(remote_worker_stats)
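
The body of assign_batch_worker is not spelled out above, so here is a minimal sketch of one way the three steps could work. It is an illustration under assumptions, not the actual Orca implementation: the helper names `assign_batch_indices` and `shard_batches` are hypothetical, and the assignment rule (a contiguous block of `num_batches // num_workers` batches per worker, with leftover batches handed out one per rank) is inferred only from the example `{0: {0, 1, 6}, 1: {2, 3, 7}, 2: {4, 5}}`. The sketch accepts any iterable of batches, so a PyTorch DataLoader could be passed in the same way as a plain list.

```python
from typing import Any, Callable, Dict, Iterable, List


def assign_batch_indices(num_batches: int, num_workers: int) -> Dict[int, List[int]]:
    """Step 1: map each worker rank to the batch indices it should process.

    Assumed rule: every worker gets a contiguous block of
    num_batches // num_workers batches; leftover batches go to the
    lowest ranks, one each.
    """
    base, remainder = divmod(num_batches, num_workers)
    assignment: Dict[int, List[int]] = {}
    for rank in range(num_workers):
        indices = list(range(rank * base, (rank + 1) * base))
        if rank < remainder:
            # Leftover batches start right after the last full block.
            indices.append(base * num_workers + rank)
        assignment[rank] = indices
    return assignment


def shard_batches(batches: Iterable[Any], num_workers: int) -> List[List[Any]]:
    """Steps 2 and 3: turn the index assignment into per-worker batch lists."""
    # Note: this materializes every batch in the calling process.
    materialized = list(batches)
    assignment = assign_batch_indices(len(materialized), num_workers)
    return [
        [materialized[i] for i in assignment[rank]]
        for rank in range(num_workers)
    ]


def make_data_creator(shard: List[Any]) -> Callable[[], List[Any]]:
    """Wrap a shard in a zero-argument creator, as in the design above."""
    def data_creator() -> List[Any]:
        return shard
    return data_creator


if __name__ == "__main__":
    demo = [f"batch{i}" for i in range(8)]
    for rank, shard in enumerate(shard_batches(demo, num_workers=3)):
        print(rank, make_data_creator(shard)())
```

Because this sketch calls `list(batches)` up front, all batches are loaded on the driver before being split; an index-only variant would pass `assign_batch_indices` results to the workers and let each worker skip batches it does not own.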

sgwhat, Jul 27 '22 01:07

@yushan111

sgwhat, Jul 28 '22 01:07

Does assign_batch_worker read all the data from the original PyTorch DataLoader and then return a list containing all of it? @hkvision

jason-dai, Jul 31 '22 00:07