
Multi-GPU training on a single node (SLURM): script freezes at the end of validation because each process is manually set to a different DataLoader

Open kfoynt opened this issue 1 year ago • 4 comments

Bug description

I am training my model on 4 GPUs on a single node with SLURM. I generate my data in prepare_data and load it in setup. The problem is that my data is very large (around 90GB), and I plan to generate more later on. As it stands, each of the 4 processes loads its own full copy of the 90GB, i.e. about 360GB of RAM in total, and I would like to avoid this.

Using NumPy memory mapping is not really an option, because each sample in my data is a 2D tensor and the samples have different numbers of rows. Packing them all together into a single numpy memmap sounds like a nightmare.
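For reference, variable-row samples can usually still be memory-mapped by concatenating all rows into one 2D array and keeping an offsets array (CSR-style), so that sample `i` is the row range `offsets[i]:offsets[i + 1]`. A minimal NumPy sketch, assuming every sample has the same number of columns (e.g. `emb_dim`) and a single dtype; `pack_to_memmap` and `PackedSamples` are hypothetical names, not code from this issue:

```python
import numpy as np


def pack_to_memmap(samples, data_path, offsets_path, dtype=np.float32):
    """Write variable-row 2D arrays into one flat .npy memmap plus an offsets array.

    All samples must share the same number of columns. Paths should end in .npy.
    """
    n_cols = samples[0].shape[1]
    row_counts = np.array([s.shape[0] for s in samples], dtype=np.int64)
    # offsets[i]:offsets[i + 1] is the row range that belongs to sample i
    offsets = np.concatenate(([0], np.cumsum(row_counts)))
    np.save(offsets_path, offsets)

    data = np.lib.format.open_memmap(
        data_path, mode="w+", dtype=dtype, shape=(int(offsets[-1]), n_cols)
    )
    for i, s in enumerate(samples):
        data[offsets[i]:offsets[i + 1]] = s
    data.flush()
    return offsets


class PackedSamples:
    """Read-only, map-style view: sample i is a slice of the shared memmap."""

    def __init__(self, data_path, offsets_path):
        self.data = np.load(data_path, mmap_mode="r")
        self.offsets = np.load(offsets_path)

    def __len__(self):
        return len(self.offsets) - 1

    def __getitem__(self, i):
        return self.data[self.offsets[i]:self.offsets[i + 1]]
```

Because every process opens the same file read-only, the operating system can share the cached pages between ranks instead of each rank holding a private copy of the data.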

Since the trainer gives me access to the world_size and global_rank of each process, I thought I could pass these values to my LightningDataModule, split my data into 4 shards in prepare_data, and then have each process load its corresponding shard in setup.
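This plan only avoids hangs if every rank ends up with the same number of samples, and therefore the same number of batches. A minimal sketch of computing each rank's contiguous slice under that constraint; `shard_bounds` is a hypothetical helper, and like `int(n / world_size)` in the data module it simply drops the remainder:

```python
def shard_bounds(n_samples: int, world_size: int, rank: int) -> tuple[int, int]:
    """Return the [start, end) slice of the data owned by `rank`.

    Every rank gets exactly n_samples // world_size samples; the last
    n_samples % world_size samples are dropped so that no rank ends up
    with an extra batch.
    """
    shard_size = n_samples // world_size
    start = rank * shard_size
    return start, start + shard_size
```

With `n = 16384` and 4 GPUs, each rank gets 4096 samples; with 10 samples and 4 ranks, each rank would get 2 and the last 2 samples would be dropped.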

This means that each process has its own DataLoader. Here is my LightningDataModule:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader

# create_data, VariableLengthDataset and variable_length_collate_fn are the
# author's own helpers and are not shown in this issue.


class MyDataModule(pl.LightningDataModule):
    def __init__(self, max_number, emb_dim, world_size, local_rank, global_rank, node_rank):
        super().__init__()

        self.max_number = max_number
        self.emb_dim = emb_dim

        # Distributed topology, passed in by hand from the trainer
        self.world_size = world_size
        self.local_rank = local_rank
        self.global_rank = global_rank
        self.node_rank = node_rank

    def prepare_data(self):
        # Lightning calls this on local rank 0 only: write one train shard per rank.
        n = 8192 * 2
        max_list_size = 8
        how_many = 2

        for i in range(self.world_size):
            n_shard = int(n / self.world_size)

            print("Start creating train data", flush=True)
            data, _, _ = create_data(n_shard, max_list_size, self.max_number, self.emb_dim, how_many)
            print("Done creating data", flush=True)

            print("Save train data", flush=True)
            torch.save(data, 'data_shard_' + str(i) + '.pt')
            print("Done saving train data", flush=True)

        # Same for the (much smaller) test set: one shard per rank.
        n_test = 16
        max_list_size_test = 100
        how_many = 1

        for i in range(self.world_size):
            n_shard = int(n_test / self.world_size)

            print("Start creating test data", flush=True)
            test_data, _, _ = create_data(n_shard, max_list_size_test, self.max_number, self.emb_dim, how_many)
            print("Done creating test data", flush=True)

            print("Save test data", flush=True)
            torch.save(test_data, 'test_data_shard_' + str(i) + '.pt')
            print("Done saving test data", flush=True)

    def setup(self, stage: str):
        # Called on every process: each rank loads only its own shard.
        print("Entering setup", flush=True)

        if stage == "fit":
            print("Load train data", flush=True)
            data = torch.load('data_shard_' + str(self.global_rank) + '.pt')
            print("Done", flush=True)

            print("Creating dataset, train", flush=True)
            self.dataset = VariableLengthDataset(data)
            print("Done", flush=True)

            print("Load test data", flush=True)
            test_data = torch.load('test_data_shard_' + str(self.global_rank) + '.pt')
            print("Done", flush=True)

            print("Creating dataset, test", flush=True)
            self.test_dataset = VariableLengthDataset(test_data)
            print("Done", flush=True)

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=32, collate_fn=variable_length_collate_fn,
                          num_workers=1, shuffle=False, pin_memory=True, persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=32, collate_fn=variable_length_collate_fn,
                          num_workers=1, shuffle=False, pin_memory=True, persistent_workers=True)
```
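One detail worth checking with this design: by default (`use_distributed_sampler=True` in Lightning 2.x), the Trainer wraps each DataLoader in a `DistributedSampler` under DDP, so each rank would iterate over only a quarter of its own, already-sharded data. A minimal sketch of Trainer arguments for pre-sharded data, assuming 4 GPUs on one node with the `ddp` strategy; `make_trainer_kwargs` is a hypothetical helper:

```python
def make_trainer_kwargs(num_gpus: int = 4) -> dict:
    """Trainer arguments for single-node DDP where each rank loads its own shard."""
    return {
        "accelerator": "gpu",
        "devices": num_gpus,
        "num_nodes": 1,
        "strategy": "ddp",
        # Each rank already holds only its own shard (loaded in setup), so do not
        # let Lightning wrap the DataLoaders in a DistributedSampler, which would
        # split that shard across the ranks a second time.
        "use_distributed_sampler": False,
    }
```

Usage: `trainer = pl.Trainer(**make_trainer_kwargs(4))`. Under SLURM, Lightning also expects `--ntasks-per-node` to match `devices`.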

I have disabled logging in validation_step completely. In training_step I call:

`self.log('loss', loss, on_step=True, on_epoch=True, batch_size=len(inputs), sync_dist=True, rank_zero_only=True)`

My script freezes either at the beginning or at the end of the validation loop, depending on the number of workers and on whether I use persistent_workers=True. Either way, it freezes.

I have run into a similar issue before, when I mistakenly had each process generate the whole dataset from scratch, which created a different dataset for each process. Back then my data was small, so I simply followed the example in the documentation for writing a LightningDataModule and everything worked fine. Now, however, I really need sharding because of the memory limits, which means that each process needs its own DataLoader to avoid duplicating the data.

Any ideas here?

What version are you seeing the problem on?

v2.1

How to reproduce the bug

Hopefully I provided enough information in my description.

Error messages and logs

No error message, it just freezes.

Environment

Current environment
- PyTorch Lightning version: 2.1.3
- PyTorch version: 2.1
- Python version: 3.11
- OS: Linux

More info

No response

kfoynt, Jan 11 '24 16:01

I added a custom sampler and set the number of workers to 0 to make sure that each GPU process uses the correct indices, but it still freezes. Here is the sampler:

```python
import torch
from torch.utils.data import DistributedSampler


class SingleChunkDistributedSampler(DistributedSampler):
    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        # The dataset on this process is already its own shard, so iterate over all of it.
        self.chunk_size = len(dataset)

        # Set these here rather than in __iter__, so that len(sampler)
        # (DistributedSampler.__len__ returns self.num_samples) is correct
        # before the first epoch starts.
        self.total_size = self.chunk_size
        self.num_samples = self.chunk_size

    def __iter__(self):
        # Indices for the current process: the whole local chunk
        indices = list(range(self.chunk_size))

        # Shuffle within the chunk if shuffle is enabled
        if self.shuffle:
            indices = self.shuffle_within_chunk(indices)

        return iter(indices)

    def shuffle_within_chunk(self, indices):
        # Seed with seed + epoch (as DistributedSampler does) so the order is
        # reproducible and changes whenever set_epoch() is called.
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        perm = torch.randperm(len(indices), generator=g).tolist()
        return [indices[i] for i in perm]
```

kfoynt, Jan 11 '24 19:01

I also set `--cpus-per-task=1` to make sure that there is no CPU multiprocessing going on. It still freezes.

kfoynt, Jan 11 '24 19:01

@kfoynt Since you implemented a custom sampler, please make sure that the length it returns is the same on every process; that is very important. As a sanity check, I think you should remove the custom sampler and see what happens.
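Why the length matters: under DDP, every training step (and every `sync_dist=True` log call) involves collective operations that all ranks must join. If one rank has more batches than the others, it blocks in a collective call that the other ranks never make, which looks like a freeze. A minimal pure-Python sketch of the check, assuming a plain DataLoader whose length is the number of batches per epoch; `batches_per_rank` and `check_same_length` are hypothetical helpers:

```python
import math


def batches_per_rank(shard_sizes, batch_size, drop_last=False):
    """Number of steps each rank runs per epoch, given its shard size."""
    if drop_last:
        return [size // batch_size for size in shard_sizes]
    return [math.ceil(size / batch_size) for size in shard_sizes]


def check_same_length(shard_sizes, batch_size, drop_last=False):
    """Raise if ranks would run different numbers of steps; else return the count."""
    counts = batches_per_rank(shard_sizes, batch_size, drop_last)
    if len(set(counts)) != 1:
        raise ValueError(
            f"ranks would run different numbers of steps {counts}; the ranks with "
            "more steps would block in a collective call the others never join"
        )
    return counts[0]
```

For example, shards of 4097, 4096, 4096 and 4096 samples with `batch_size=32` give `[129, 128, 128, 128]` steps, so rank 0 would wait on its last step; `drop_last=True` evens this out to 128 each.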

On another note, regarding this line you mentioned:

`self.log('loss', loss, on_step=True, on_epoch=True, batch_size=len(inputs), sync_dist=True, rank_zero_only=True)`

Please read this section of the docs. I doubt that you should set `rank_zero_only=True` there.

awaelchli, Jan 18 '24 13:01

@kfoynt Can you take a look at my suggestion? Do you have any new insights?

awaelchli, Jan 23 '24 13:01

@kfoynt Can you take a look at my replies?

awaelchli, Feb 05 '24 03:02

Hey, sorry, I was working on a conference deadline. I removed the custom sampler; it wasn't working. I think padding all samples to the same number of rows could work, but I would like to avoid having my model predict so many zeros.
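On the concern about predicting zeros: padding does not force the model to learn the padding if the collate function also returns a mask and the loss ignores the padded rows. A minimal NumPy sketch, assuming all samples share the same number of columns and an MSE-style loss; `pad_with_mask` and `masked_mse` are hypothetical names (in PyTorch the same idea is commonly written with `torch.nn.utils.rnn.pad_sequence` plus a boolean mask):

```python
import numpy as np


def pad_with_mask(samples, pad_value=0.0):
    """Pad variable-row 2D samples to the longest one in the batch.

    Returns (batch, mask): batch has shape (B, max_rows, n_cols), and
    mask[b, r] is True only for real (non-padded) rows.
    """
    max_rows = max(s.shape[0] for s in samples)
    n_cols = samples[0].shape[1]
    batch = np.full((len(samples), max_rows, n_cols), pad_value, dtype=samples[0].dtype)
    mask = np.zeros((len(samples), max_rows), dtype=bool)
    for b, s in enumerate(samples):
        batch[b, : s.shape[0]] = s
        mask[b, : s.shape[0]] = True
    return batch, mask


def masked_mse(pred, target, mask):
    """Mean squared error over real rows only; padded rows contribute nothing."""
    sq_err = (pred - target) ** 2          # shape (B, max_rows, n_cols)
    sq_err = sq_err * mask[..., None]      # zero out padded rows
    n_real = mask.sum() * pred.shape[-1]   # number of real elements
    return float(sq_err.sum() / max(n_real, 1))
```

Since padded rows are masked out of the loss, the model gets no gradient signal for whatever it predicts there.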

I solved the problem for now by simply buying 1TB more RAM :-).

kfoynt, Feb 05 '24 04:02

Ok, thanks for getting back. I'm glad you could find a workaround. So it does look like this was due to the sampler not returning the same number of items on each process, which is a requirement. I'm closing the issue for now.

awaelchli, Feb 11 '24 01:02


Hi! I have a similar question. I am training on 8 GPUs, and I make sure that my sampler yields the same number of samples on every GPU, but my collate function produces a different number of training samples per batch (so each GPU ends up with a different batch size). In this case my training script gets stuck; with a single GPU it works correctly. Does Lightning require the batch size to be the same on every GPU?

ChangxinWang, Feb 22 '24 03:02


Additionally, I am now using Lightning 2.0.0. With Lightning 1.9.5, I could use a different batch size on each GPU and training worked correctly (I implemented a max-tokens sampler for an NLP task, which produces different batch sizes depending on the maximum number of tokens).

ChangxinWang, Feb 22 '24 03:02
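On the follow-up question: for DDP's gradient all-reduce, what has to match across ranks is the number of steps, not the batch size, since gradients have the same shape either way (collectives on batch-shaped tensors, such as `all_gather`, are a different matter). A minimal sketch of max-tokens batching that keeps batch sizes variable but gives every rank the same number of batches, assuming per-sample token lengths are known up front and the sample order is identical on all ranks; `max_token_batches` and `batches_for_rank` are hypothetical names, and this does not explain the 1.9.5 vs 2.0.0 difference reported above:

```python
def max_token_batches(lengths, max_tokens):
    """Greedily group sample indices so each batch's total length stays <= max_tokens.

    A single sample longer than max_tokens gets a batch of its own.
    """
    batches, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(lengths):
        if current and current_tokens + n_tokens > max_tokens:
            batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        batches.append(current)
    return batches


def batches_for_rank(lengths, max_tokens, rank, world_size):
    """Build the same global batch list on every rank, then take every
    world_size-th batch. Trailing batches are dropped so that all ranks get
    the same number of batches, even though the batch sizes differ."""
    batches = max_token_batches(lengths, max_tokens)
    usable = len(batches) - len(batches) % world_size
    return batches[rank:usable:world_size]
```

The per-rank list can be passed to a DataLoader as `batch_sampler`; with Lightning you would likely also set `use_distributed_sampler=False` so that the hand-made split is not re-wrapped.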