
Draft: Add multi gpu support

Open jeffpicard opened this issue 1 year ago • 8 comments

Hi! This is a draft PR that adds multi-GPU support. @alanakbik and others: would you be interested in incorporating something like this? The core functionality is working, and I've pasted a short script below demonstrating its usage. I get a near-linear speedup: one epoch took 368s on 16 CPUs, 32s on 1 GPU, and 8s on 4 GPUs, running on an AWS g5.12xlarge instance with 4 A10 GPUs.
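For reference, the quoted timings translate into the following speedups. The numbers are the ones reported above; the helpers only do the arithmetic, and "scaling efficiency" here means measured speedup divided by the number of GPUs.

```python
# Epoch times quoted above for one epoch on a g5.12xlarge (seconds).
timings = {"16 cpus": 368.0, "1 gpu": 32.0, "4 gpus": 8.0}


def speedup(baseline_s: float, candidate_s: float) -> float:
    """How many times faster the candidate run is than the baseline."""
    return baseline_s / candidate_s


def scaling_efficiency(one_gpu_s: float, n_gpu_s: float, n_gpus: int) -> float:
    """Fraction of ideal linear scaling achieved (1.0 means perfectly linear)."""
    return speedup(one_gpu_s, n_gpu_s) / n_gpus


print(f"1 gpu vs 16 cpus: {speedup(timings['16 cpus'], timings['1 gpu']):.2f}x")  # 11.50x
print(f"4 gpus vs 1 gpu:  {speedup(timings['1 gpu'], timings['4 gpus']):.2f}x")   # 4.00x
print(f"efficiency: {scaling_efficiency(timings['1 gpu'], timings['4 gpus'], 4):.0%}")  # 100%
```

For this single measurement, 4 GPUs are exactly 4x faster than one, i.e. linear scaling relative to the single-GPU run.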

There's a related issue here, and a past PR that never ended up merging.

The approach:

  • This PR uses raw PyTorch DistributedDataParallel rather than another package like Fabric, Accelerate, or DeepSpeed. This gives more control and visibility into exactly what's happening, and avoids having to adopt another large PyTorch project's design for handling things like AMP. However, it leaves more to be handled in flair, such as multi-node and TPU support. I'm open to discussing/implementing other approaches if you have preferences.
  • In order to use multiple GPUs, users would call a launch_distributed function. This means 1) user code will run num_gpus times, which can be unintuitive, and 2) existing flair scripts won't automatically use multiple GPUs without refactoring. A simpler approach may be possible by spawning processes inside Trainer.train_custom. However, I ran into problems doing it this way (e.g. TransformerEmbeddings and Pluggable._event_queue would not serialize correctly), and many multi-GPU projects involve this kind of complexity. I still think this PR is a step toward that better future, and existing CPU/single-GPU usage is unchanged.
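To make "user code will run num_gpus times" concrete: a launcher in the style of launch_distributed starts one process per GPU, and each process typically learns its identity from environment variables. The sketch below only builds those per-process environments. MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE are the variables read by torch.distributed's env:// initialization, and LOCAL_RANK follows the torchrun convention; the helper build_rank_envs and the default port are hypothetical and not part of flair.

```python
from typing import Dict, List


def build_rank_envs(num_gpus: int, master_addr: str = "127.0.0.1",
                    master_port: int = 29500) -> List[Dict[str, str]]:
    """Return one environment dict per process for a single-node launch.

    Hypothetical helper: a launcher would start `num_gpus` copies of the
    user's training function, each with one of these environments, so the
    same user code runs once per GPU.
    """
    if num_gpus < 1:
        raise ValueError("num_gpus must be >= 1")
    envs = []
    for rank in range(num_gpus):
        envs.append({
            "MASTER_ADDR": master_addr,       # rendezvous host shared by all ranks
            "MASTER_PORT": str(master_port),  # rendezvous port shared by all ranks
            "WORLD_SIZE": str(num_gpus),      # total number of processes
            "RANK": str(rank),                # global id of this process
            "LOCAL_RANK": str(rank),          # GPU index on this node (single node)
        })
    return envs


for env in build_rank_envs(4):
    print(env["RANK"], env["LOCAL_RANK"], env["WORLD_SIZE"])
```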

There are still TODOs. For example, the logging inside .train_custom prints multiple times (once per process/GPU). If you like the approach, I can add commits fixing this with statements like if is_main_process(): or torch.distributed.gather_object to aggregate metrics across processes, similar to what's done for the eval steps in this PR.

Example usage:

import flair
import torch
from flair.data import Sentence
from flair.datasets import TREC_6
from flair.distributed_utils import launch_distributed
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def example(max_epochs):
    corpus = TREC_6()
    label_type = "question_class"
    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=max_epochs)
    model.predict(Sentence("Hello, world!"))


if __name__ == "__main__":
    mode = "multi_gpu"  # one of "multi_gpu", "single_gpu", "cpu"
    epochs = 2
    if mode == "multi_gpu":
        # Starts one process per GPU; each process runs example(epochs)
        launch_distributed(example, epochs)
    elif mode == "single_gpu":
        example(epochs)
    elif mode == "cpu":
        # flair.device is normally set to a torch.device, not a plain string
        flair.device = torch.device("cpu")
        example(epochs)
    print("Done")

jeffpicard • Sep 24 '24

Hello @jeffpicard, this is awesome, thanks for the PR!

@helpmefindaname @HallerPatrick can you take a look?

alanakbik • Sep 25 '24

Hey @jeffpicard, thanks for the PR.

I tested your changes with different numbers of GPUs and can, more or less, reproduce your speedups!

I also like the approach of setting everything up in-process to isolate the distribution logic only for the training logic. For the logging, we could keep it simple with:

if flair.distributed:
    self.model = DistributedModel(self.model, device_ids=[flair.device.index])

    # Disable logging in distributed mode for all but the main process
    log.disabled = not is_main_process()
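The log.disabled idea relies on standard logging behavior: a Logger whose disabled attribute is True drops every record. A minimal sketch, assuming the rank is read from a RANK environment variable (flair's own is_main_process may determine it differently, e.g. via torch.distributed); the logger name and ListHandler are illustrative only.

```python
import logging
import os


def is_main_process() -> bool:
    """Assumption: rank 0 is the main process and the rank is exported in RANK."""
    return int(os.environ.get("RANK", "0")) == 0


class ListHandler(logging.Handler):
    """Collects emitted messages so the effect of `disabled` is visible."""

    def __init__(self) -> None:
        super().__init__()
        self.messages = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


log = logging.getLogger("flair_sketch")
log.setLevel(logging.INFO)
log.propagate = False
handler = ListHandler()
log.addHandler(handler)

# Every rank runs this line; only rank 0 keeps its logger enabled.
log.disabled = not is_main_process()
log.info("epoch 1 done")
```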

Here are some points from my side:

  1. I am a little suspicious about the DistributedModel wrapper, since we can now arbitrarily update the DistributedDataParallel model without knowing whether it affects any distributed logic. I see the convenience of it. Maybe we can check every __getattr__ and __setattr__ call just to be on the safe side here :P

  2. Model saving logic is still distributed. Easily fixable

  3. How should we handle the "best model" logic after each epoch? Should we just naively test the main process model? I don't know if this is nitpicky...
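The __getattr__/__setattr__ check from point 1 could look roughly like this. It is a plain-Python sketch with hypothetical names, not flair's DistributedModel: torch.nn.Module defines its own __getattr__/__setattr__ for parameters and submodules, so a real version would have to cooperate with that machinery rather than replace it.

```python
class GuardedWrapper:
    """Plain-Python sketch of a wrapper that delegates to an inner model.

    Reads of unknown attributes fall through to the wrapped object; writes
    to attributes that live on the wrapped object are rejected so they
    cannot silently diverge from what the distributed wrapper expects.
    """

    def __init__(self, module):
        # Bypass our own __setattr__ guard while storing the inner object.
        object.__setattr__(self, "module", module)

    def __getattr__(self, name):
        # Only called when normal lookup on the wrapper fails.
        if name == "module":
            raise AttributeError(name)  # avoid infinite recursion if unset
        return getattr(self.module, name)

    def __setattr__(self, name, value):
        if hasattr(self.module, name):
            raise AttributeError(
                f"refusing to set {name!r} on the wrapper; set it on .module explicitly"
            )
        object.__setattr__(self, name, value)


class Toy:
    def __init__(self):
        self.label_type = "sentiment"


wrapped = GuardedWrapper(Toy())
print(wrapped.label_type)          # read falls through to the inner object
wrapped.note = "wrapper-only"      # allowed: not an attribute of the inner object
try:
    wrapped.label_type = "topic"   # rejected: would shadow the inner attribute
except AttributeError as err:
    print(err)
```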

On a side note, maybe we can also implement multi-gpu support for the LM trainer @alanakbik :)

Thank you!

HallerPatrick • Sep 25 '24

Many thanks for the thoughtful review!

isolate the distribution logic only for the training logic

I'll look into distributing across processes inside the call to .train/.fine_tune rather than before. Some of the serialization issues (e.g. Pluggable._event_queue) should be solvable, I think.

log.disabled = not is_main_process()

Ah, great idea, thanks!

  1. I felt the same way! Thanks for calling it out. The idea is inspired by other implementations like Lightning Fabric. I'll be more careful about the implementation.

  2. I did add an if is_main_process() to Model.save, but I can move that in front of all the calls to model.save in trainer to be less surprising.

  3. I believe testing the main process model should be fine, since the models on each process/GPU should be identical. The data also needs to be the same: the dev Dataset should already be identical on each process, but if train_loss is used, it is calculated only on the fraction of the data that the given process handles. I'll try torch.distributed.gather_object(train_loss) to average it across all processes/GPUs. This will also help with logging the training progress.
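When shards differ in size, a plain mean of per-process average losses is biased. A sketch of the aggregation, assuming each process contributes a (loss_sum, datapoint_count) pair, mirroring what forward_loss returns; in a real run that list would be collected with something like torch.distributed.all_gather_object, which is left out so the example runs without GPUs. The function name is hypothetical.

```python
from typing import List, Tuple


def global_mean_loss(per_rank: List[Tuple[float, int]]) -> float:
    """Combine (loss_sum, datapoint_count) pairs from all ranks into one mean.

    Summing numerators and denominators separately weights each rank by how
    many data points it actually processed.
    """
    total_loss = sum(loss_sum for loss_sum, _ in per_rank)
    total_count = sum(count for _, count in per_rank)
    if total_count == 0:
        raise ValueError("no data points were processed on any rank")
    return total_loss / total_count


# Two ranks with uneven shards: rank 0 saw 3 points, rank 1 saw 1 point.
per_rank = [(3.0, 3), (3.0, 1)]
print(global_mean_loss(per_rank))  # 1.5 (a naive mean of the per-rank means gives 2.0)
```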

I'll follow up soon.

jeffpicard • Sep 25 '24

Hi @jeffpicard

Thank you for creating this draft. Conceptually, I think this is a good way to finally integrate multi-GPU training into flair.

I tested this on 2 RTX A4000s, increasing mini_batch_chunk_size until all GPU memory was used, and setting mini_batch_size either to the same value (multi-GPU) or to 2x (single-GPU) to have a fair comparison in terms of batch updates. I also used ClearML for logging. With that, I can confirm the ~2x speed improvement for 2 GPUs, and that the metrics at the end are about the same (although slightly worse for multi-GPU).

I observed that the logging in multi-GPU mode is somehow off by 1 epoch (screenshot not included): there was no report for epoch 1, but an epoch 21 magically appeared. I am not sure why that is.

Also, since the non-main processes currently log values too, I observed the following (screenshot not included): the non-main process is ahead of the main process, because it doesn't have to evaluate. I am not sure if that is good, or if we should rather synchronize the processes at the end of each epoch. Splitting the evaluation would obviously also be an option, but I think that would imply a lot of changes that make this PR more complicated.

I wonder how the plugins are impacted by multi-GPU training. Logger plugins should obviously only run on the main process, while others, like the lr-scheduler plugins, need to run on every process. Note: currently the lr-scheduler doesn't know that multi-GPU training uses a higher effective batch size and therefore fewer train steps (screenshot not included).
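The scheduler mismatch comes down to arithmetic: with W processes, each one sees roughly N/W examples per epoch, so the number of optimizer steps shrinks by about a factor of W while the effective batch grows to W times the per-process mini-batch. A sketch of the corrected count, assuming mini_batch_size is per process and that each process gets ceil(N / W) examples (as torch's DistributedSampler does by default, padding with repeats); the function names are illustrative.

```python
import math


def optimizer_steps(num_examples: int, mini_batch_size: int, world_size: int,
                    epochs: int) -> int:
    """Total optimizer steps one process takes during training.

    Assumes each process gets ceil(num_examples / world_size) examples per
    epoch (padding, no drop_last) and batches them by mini_batch_size.
    """
    per_process = math.ceil(num_examples / world_size)
    steps_per_epoch = math.ceil(per_process / mini_batch_size)
    return steps_per_epoch * epochs


def effective_batch_size(mini_batch_size: int, world_size: int) -> int:
    """Examples contributing to each synchronized gradient update."""
    return mini_batch_size * world_size


print(optimizer_steps(1000, 16, world_size=1, epochs=10))  # 630
print(optimizer_steps(1000, 16, world_size=4, epochs=10))  # 160
```

A warmup/decay schedule sized for 630 steps would, in the 4-process case, stop after 160 steps, far from the end of the schedule.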

Using the AnnealOnPlateau scheduler doesn't work, as the non-main processes fail without an eval metric.

helpmefindaname • Oct 04 '24

Thanks for looking at this @helpmefindaname !

logging at multi-gpu is off by 1 epoch

Ahh, sorry about that. I think it's from the new call to .set_epoch(epoch) which was off by 1.
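For context on set_epoch: torch's DistributedSampler reshuffles with a seed derived from seed + epoch, so every process must pass the same epoch value, and that value selects the permutation. This pure-Python sketch mimics the sharding logic (shuffle, pad to a multiple of the world size, then take every W-th index starting at the rank) without depending on torch. It uses Python's random module, so the actual permutations differ from torch's; shard_indices is a hypothetical name.

```python
import math
import random
from typing import List


def shard_indices(n: int, rank: int, world_size: int, epoch: int,
                  seed: int = 0) -> List[int]:
    """Indices one process visits in a given epoch (DistributedSampler-like)."""
    indices = list(range(n))
    # Same seed + epoch on every rank -> identical permutation everywhere.
    random.Random(seed + epoch).shuffle(indices)
    total = math.ceil(n / world_size) * world_size
    padding = total - n
    if padding > 0:
        # Repeat from the front (possibly several times) so every rank gets the same count.
        indices += (indices * math.ceil(padding / n))[:padding]
    # Interleaved split: rank r takes positions r, r + W, r + 2W, ...
    return indices[rank:total:world_size]


epoch0 = [shard_indices(10, r, 2, epoch=0) for r in range(2)]
print(epoch0)
```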

the non-main process is ahead of the main process [...] we should rather synchronize the processes

DistributedDataParallel should be synchronizing every backward(), but there was a bug. I fixed it.

plugins

Thanks to @hallerpatrick's good idea to disable the logger on all but the main process, I've simplified the plugins to run on all processes. This makes the AnnealOnPlateau work. However, yes, if a plugin needs to synchronize information from all processes, it'll have to explicitly do that.


I ran into an unfortunate wrinkle: I no longer see a speedup after the following bug fix. I noticed the gradients were not the same on each process/GPU for a given epoch_num and batch_no, as they should be. I think this is because PyTorch's synchronization implementation relies on hooks that get called when you __call__ a model, rather than when you just call forward_loss. Changing the Trainer:

loss, datapoint_count = self.model.forward_loss(batch_step)
# becomes
loss, datapoint_count = self.model(batch_step)

fixes the gradients, but makes multiple GPUs a bit slower than a single GPU. Any idea what could be going on that's making it slower?
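Why calling forward_loss directly skips synchronization: DistributedDataParallel prepares gradient synchronization inside its own forward, which only runs when the wrapped model is called through __call__. The toy class below is not torch; it only mimics that dispatch to show that a direct method call never reaches the bookkeeping that __call__ wraps around forward.

```python
class ToyModule:
    """Mimics the shape of nn.Module dispatch: __call__ -> bookkeeping -> forward."""

    def __init__(self):
        self.calls = []

    def __call__(self, *args, **kwargs):
        self.calls.append("pre-forward bookkeeping")  # stands in for DDP sync setup
        out = self.forward(*args, **kwargs)
        self.calls.append("post-forward bookkeeping")
        return out

    def forward(self, batch):
        return self.forward_loss(batch)

    def forward_loss(self, batch):
        # Pretend loss: sum of the batch, plus the datapoint count.
        return sum(batch), len(batch)


m = ToyModule()
m.forward_loss([1, 2, 3])  # bookkeeping skipped
print(m.calls)             # []
m([1, 2, 3])               # bookkeeping runs
print(m.calls)             # ['pre-forward bookkeeping', 'post-forward bookkeeping']
```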

jeffpicard • Oct 07 '24

Any idea what could be going on that's making it slower?

Aha, with a bigger batch size, multiple GPUs are faster again. There's a little overhead to synchronizing the gradients, so the bigger the batch size, the more the overhead can be amortized.
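A rough cost model for why larger batches help: gradient synchronization moves one copy of the gradients per optimizer step regardless of batch size, while compute grows with the number of examples. The function below is an illustrative model with made-up inputs, not a measurement; it ignores DDP's overlap of communication with the backward pass, so it overstates the overhead.

```python
def sync_overhead_fraction(batch_size: int, compute_s_per_example: float,
                           sync_s_per_step: float) -> float:
    """Share of a training step spent synchronizing gradients.

    Illustrative model: compute scales linearly with batch size, gradient
    synchronization costs a fixed amount per step (gradient size does not
    depend on batch size), and the two do not overlap.
    """
    compute = batch_size * compute_s_per_example
    return sync_s_per_step / (compute + sync_s_per_step)


for bs in (4, 16, 64):
    print(bs, round(sync_overhead_fraction(bs, 0.01, 0.04), 3))
```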

I've fixed most of what's mentioned above

  • Process forking now happens inside .train so all users have to do is add the multi_gpu=True argument
  • The metrics logged during training are now averaged/summed from all GPUs rather than printing the rank=0 data
  • Removed DistributedModel wrapper

I'll push these changes up.


I'm still stuck on:

  • What to do about forward vs. forward_loss. To synchronize gradients, PyTorch relies on hooks run by __call__, which then invokes the special method forward. flair's trainer relies on forward_loss instead. That is potentially convenient, because forward could simply be redirected to forward_loss, but some Models already implement forward. One option is to refactor all models so that either all or none of them use forward, but that's complex ¯\_(ツ)_/¯.
  • I need to make TransformerEmbeddings work with pickle. Currently getting TypeError: DistilBertModel.__init__() got an unexpected keyword argument 'instance_parameters'.
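On the serialization problems in general: spawning processes pickles the objects passed to them, and objects holding locks or queues cannot be pickled. One common workaround, shown on a hypothetical stand-in for a class like Pluggable that owns an event queue (not flair's actual code), is to drop the unpicklable attribute in __getstate__ and recreate it in __setstate__. It does not address the TransformerEmbeddings constructor error above, which has a different cause.

```python
import pickle
import queue


class EventSource:
    """Hypothetical stand-in for an object that owns an event queue."""

    def __init__(self, name: str):
        self.name = name
        self._event_queue = queue.Queue()  # holds thread locks: not picklable

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_event_queue", None)  # leave the queue behind
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self._event_queue = queue.Queue()  # fresh, empty queue in the new process


src = EventSource("trainer")
src._event_queue.put("pending event")
clone = pickle.loads(pickle.dumps(src))
print(clone.name, clone._event_queue.qsize())  # trainer 0
```

Pending events are not carried over to the new process, which is acceptable only if the queue is empty or rebuilt at spawn time.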

Let me know if you have any thoughts on forward.

jeffpicard • Oct 08 '24

And here's an example of running it on the latest commit:

from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer

if __name__ == "__main__":
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"

    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)

jeffpicard • Oct 08 '24

What to do about forward vs forward_loss

Oh, this can be resolved without a big refactor by patching forward similar to what Fabric does here.
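The patching idea can be sketched as a context manager that temporarily points forward at forward_loss, so that calling the model goes through __call__ (and, with DDP, through DDP's own forward) while still computing flair's loss. With real DDP the patch would go on the inner module (ddp_model.module). This is a plain-Python illustration with a toy model; forward_redirected is a hypothetical name, and Lightning Fabric's actual implementation differs in detail.

```python
from contextlib import contextmanager


class ToyModel:
    """Toy with nn.Module-like dispatch: __call__ wraps forward."""

    def __init__(self):
        self.hook_runs = 0

    def __call__(self, *args, **kwargs):
        self.hook_runs += 1  # stands in for DDP's pre/post-forward work
        return self.forward(*args, **kwargs)

    def forward(self, sentences):
        return ["prediction"] * len(sentences)  # the model's "real" forward

    def forward_loss(self, sentences):
        return float(len(sentences)), len(sentences)


@contextmanager
def forward_redirected(model, method_name="forward_loss"):
    """Temporarily make model.forward call another method via an instance attribute."""
    model.forward = getattr(model, method_name)  # shadows the class-level forward
    try:
        yield model
    finally:
        del model.forward  # the class-level forward is visible again


model = ToyModel()
with forward_redirected(model):
    loss, count = model(["a", "b"])   # goes through __call__, returns the loss
print(loss, count, model.hook_runs)   # 2.0 2 1
print(model(["a"]))                   # ['prediction']
```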

make TransformerEmbeddings work with pickle

I see there are other objects that can't be pickled, like Sentences with spans. Since a lot of things might not work, I plan to go back to the launch_distributed approach from the example in the original description. Doing it that way doesn't require everything to be pickleable. There are a couple of other options, like invoking the script with torchrun script.py, or integrating Lightning Fabric, which has a launcher that doesn't require objects to be pickleable. If you prefer either of those, or committing to making all objects pickleable, let me know.

I'll add a commit soon with these changes, which I hope can take this out of draft.

jeffpicard • Oct 18 '24

Done! @helpmefindaname and @HallerPatrick can you please take another look? This looks good to me. What do you think about merging?

  • Plugins
    • I added a property so that each plugin can set whether it gets run on all processes or just the main one. Defaults to true (seemed safer). AnnealingPlugin and LinearSchedulerPlugin are the only plugins that run on all processes.
  • lr-scheduler doesn't know that multi-gpu training uses a higher batch-size

    • I modified the calculation to take this into account.
  • Example script
    • I modified the script in the top comment to reflect a minimal example of running this.
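The per-plugin flag can be pictured as a boolean attribute that the trainer consults before attaching each plugin. A minimal sketch; the attribute name attach_to_all_processes and the classes are assumptions for illustration, not necessarily flair's actual names.

```python
from typing import List


class Plugin:
    # Assumed flag name; True means "attach on every process".
    attach_to_all_processes: bool = True


class SchedulerPlugin(Plugin):
    pass  # needs to step the learning rate on every process


class LoggerPlugin(Plugin):
    attach_to_all_processes = False  # only the main process should write logs


def plugins_for_process(plugins: List[Plugin], is_main_process: bool) -> List[Plugin]:
    """Select the plugins a given process should attach."""
    return [p for p in plugins if is_main_process or p.attach_to_all_processes]


plugins = [SchedulerPlugin(), LoggerPlugin()]
print([type(p).__name__ for p in plugins_for_process(plugins, is_main_process=True)])
print([type(p).__name__ for p in plugins_for_process(plugins, is_main_process=False)])
```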

jeffpicard • Oct 25 '24

Thanks for running CI. I think the same error occurs on master. I think @helpmefindaname already fixed it in this commit but it hasn't merged yet. I've added that commit's diff to this branch, with a minor bug fix changing del to pop.

I think CI should pass now.

jeffpicard • Oct 29 '24

Hey @jeffpicard, sorry for the late replies.

So I am taking a look now. I am testing with your example:

from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.distributed_utils import launch_distributed


def main():
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"

    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)


if __name__ == "__main__":
    launch_distributed(main)

I think in your earlier message you forgot to add launch_distributed. Is that right?

I also had problems running the example, because I tried running the script with torchrun/torch.distributed.launch, which is my fault. But we should definitely add documentation for multi-GPU!

Then I ran into another problem: the models that should be moved to different devices have different parameters:

Moving model on device: 0
TextClassifier(
  (embeddings): DocumentTFIDFEmbeddings()
  (decoder): Linear(in_features=8921, out_features=2, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): None
  (weight_tensor) None
)
Moving model on device: 1
TextClassifier(
  (embeddings): DocumentTFIDFEmbeddings()
  (decoder): Linear(in_features=9539, out_features=2, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): None
  (weight_tensor) None
)

The decoder has a different number of input features. Could this be because we start multiple processes and the data processing yields different results in each?
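Differing in_features are consistent with each process downsampling a different random subset, which then yields a different TF-IDF vocabulary. The fix direction is easy to see with the standard library: generators seeded identically on every process pick the same subset. This illustration uses random.Random and a hypothetical downsample helper; it does not use flair's own downsample implementation.

```python
import random
from typing import List


def downsample(items: List[str], fraction: float, rng: random.Random) -> List[str]:
    """Keep a random `fraction` of items (illustrative, not flair's implementation)."""
    k = max(1, int(len(items) * fraction))
    return sorted(rng.sample(items, k))


corpus = [f"doc{i}" for i in range(1000)]

# Each "process" seeds its generator with the same value -> identical subsets.
seeded = [downsample(corpus, 0.01, random.Random(42)) for _rank in range(2)]
print(seeded[0] == seeded[1])  # True
```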

Generally data preprocessing should only be done on the main process. So this works:

from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.distributed_utils import launch_distributed


def main():
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    launch_distributed(train, corpus, label_type, label_dictionary)

def train(corpus, label_type, label_dictionary):
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)

if __name__ == "__main__":
    main()

I think in Hugging Face all preprocessing operations are agnostic to this, but I am not sure.

I don't know if it makes sense to add some type of guard to all data processing operations (sounds like a lot of work), or to just make it quite clear that simply adding multi_gpu=True to the trainer will not work.

Any thoughts on that?

HallerPatrick • Nov 01 '24

Hey @HallerPatrick. Thanks again, I really appreciate your continued thought on this and bearing with me.

Oops, sorry about that, I miscommunicated which script was the latest one to use. I just added a commit putting the example script into the repo under examples/multi_gpu along with some documentation in the README to explain what to know alongside adding multi_gpu=True.

add some type of guards

Great idea. I added a guard in .train_custom in the latest commit so that training will raise an exception if the corpus is different on different processes (which can then lead to other preprocessing being different). In the docs, I included your method of initializing the corpus before spawning. However, that method won't work e.g. for NER since Sentences with entity spans aren't serializable. So I also mentioned setting the random seed to solve this. The latest commit adds an __eq__ method to Sentence to enable the guard.
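A guard of this kind can be built from a fingerprint: each process hashes the text of its corpus in order, the fingerprints are exchanged (with torch.distributed this could be done via all_gather_object), and training stops if they disagree. The sketch covers only hashing and comparison, works on plain strings so it does not need Sentence objects to be serializable, and uses hypothetical function names.

```python
import hashlib
from typing import Iterable, List


def corpus_fingerprint(texts: Iterable[str]) -> str:
    """Order-sensitive SHA-256 over the texts of a corpus split."""
    h = hashlib.sha256()
    for text in texts:
        data = text.encode("utf-8")
        h.update(len(data).to_bytes(8, "little"))  # length prefix avoids ambiguity
        h.update(data)
    return h.hexdigest()


def check_same_corpus(fingerprints: List[str]) -> None:
    """Raise if any process saw a different corpus (one fingerprint per rank)."""
    if len(set(fingerprints)) > 1:
        raise RuntimeError(
            "corpus differs across processes; load it identically on every rank "
            "(e.g. with a fixed random seed) before training"
        )


rank0 = corpus_fingerprint(["great movie", "terrible plot"])
rank1 = corpus_fingerprint(["great movie", "terrible plot"])
check_same_corpus([rank0, rank1])  # passes silently
```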

What do you think? (I'm not aware of outstanding issues)

jeffpicard • Nov 08 '24

I just rebased to pull in the type-checking updates to master and resolved conflicts. The guard wasn't working with non-serializable sentences so I added a commit fixing that (and removing the Sentence.__eq__ method since it's no longer needed).

How's this look?

jeffpicard • Nov 15 '24

Any thoughts on merging this @HallerPatrick @helpmefindaname?

jeffpicard • Nov 22 '24

@jeffpicard Thank you very much for your help here! Sorry it took so long :P

Are there any open issues related to this PR?

HallerPatrick • Nov 22 '24

A big thank you for all your hard work in adding this @jeffpicard - this is a huge new feature that many people have been waiting for! Can't wait to see this in action :)

alanakbik • Nov 22 '24

For sure! Thank you for building a framework that's so easy to love.

jeffpicard • Nov 22 '24