Draft: Add multi gpu support
Hi! This is a draft PR that adds multi-GPU support. @alanakbik and others: would you be interested in incorporating something like this? The core functionality is working and I've pasted a short script below demonstrating its usage. I get a near-linear speed increase -- for 1 epoch it took: 16 CPUs --> 368s, 1 GPU --> 32s, 4 GPUs --> 8s when running on an AWS g5.12xlarge instance with 4 A10 GPUs.
There's a related issue here, and a past PR that never ended up merging.
The approach:
- This PR uses raw pytorch's `DistributedDataParallel` rather than another package like fabric, accelerate, or deepspeed. This gives more control and visibility into exactly what's happening and avoids needing to integrate another large pytorch project's design on how to handle e.g. AMP. However, it leaves more to be handled in flair, such as multi-node / TPUs etc. I'm open to discussing/implementing other approaches if you have preferences.
- In order to use multiple GPUs, users would call a `launch_distributed` mechanism. This means 1) user code will be running num_gpus times, which can be unintuitive, and 2) existing flair scripts won't automatically use multiple GPUs without refactoring. I think a simpler approach may be possible by spawning processes inside `Trainer.train_custom`. However, I ran into problems doing it this way (e.g. `TransformerEmbeddings` and `Pluggable._event_queue` would not serialize correctly), and many multi-gpu projects involve this kind of complexity. I think this PR is still a step toward that better future though, and existing CPU/single-GPU usage is unchanged.
There are still TODOs. For example, the logging inside `.train_custom` prints out multiple times (once for each process/GPU). If you connect with the approach, I can add new commits fixing this by adding statements like `if is_main_process():` or `torch.distributed.gather_object` to aggregate metrics across processes, similar to what's done for the eval steps in this PR.
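For illustration, here's a rough sketch of the kind of gating I mean; the helper body is an assumption about how `is_main_process` could look, not necessarily the final code:

import torch.distributed as dist

def is_main_process() -> bool:
    # Rank 0 is the main process; a non-distributed run counts as "main" too.
    return not dist.is_initialized() or dist.get_rank() == 0

# Usage inside the training loop (log, epoch, train_loss are placeholders):
# if is_main_process():
#     log.info(f"EPOCH {epoch} done: loss {train_loss:.4f}")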
Example usage:
import flair
from flair.data import Sentence
from flair.datasets import TREC_6
from flair.distributed_utils import launch_distributed
from flair.embeddings import TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer


def example(max_epochs):
    corpus = TREC_6()
    label_type = "question_class"
    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=max_epochs)
    model.predict(Sentence("Hello, world!"))


if __name__ == "__main__":
    mode = "multi_gpu"
    epochs = 2
    if mode == "multi_gpu":
        launch_distributed(example, epochs)
    elif mode == "single_gpu":
        example(epochs)
    elif mode == "cpu":
        flair.device = "cpu"
        example(epochs)
    print("Done")
Hello @jeffpicard this is awesome, thanks for the PR!
@helpmefindaname @HallerPatrick can you take a look?
Hey @jeffpicard, thanks for the PR.
I tested your changes with different numbers of GPUs and can, more or less, reproduce your speedups!
I also like the approach of setting everything up in-process to isolate the distribution logic only for the training logic. For the logging, we could go simple with:
if flair.distributed:
    self.model = DistributedModel(self.model, device_ids=[flair.device.index])

# Disable logging in distributed mode for all but the main process
log.disabled = not is_main_process()
Here are some points from my side:
- I am a little suspicious about the `DistributedModel` wrapper, where we can now arbitrarily update the `DistributedDataParallel` model without knowing if it affects any distributed logic. I see the convenience of it. Maybe we can check every `__getattr__` and `__setattr__` call just to be on the safe side here :P
- Model saving logic is still distributed. Easily fixable.
- How to handle the "best model" logic after each epoch: should we just naively test the main process model? I don't know if this is nitpicky...
On a side note, maybe we can also implement multi-gpu support for the LM trainer @alanakbik :)
Thank you!
Many thanks for the thoughtful review!
isolate the distribution logic only for the training logic
I'll look into distributing across processes inside the call to .train/.fine_tune rather than before. Some of the serialization issues (e.g. Pluggable._event_queue) should be solvable, I think.
log.disabled = not is_main_process()
Ah, great idea, thanks!
- I felt the same way! Thanks for calling it out. The idea is inspired by other implementations like Lightning Fabric. I'll be more careful about the implementation.
- I did add an `if is_main_process()` to `Model.save`, but I can move that in front of all the calls to `model.save` in the trainer to be less surprising.
- I believe testing the main process model should be fine, since the models on each process/GPU should be the same. However, the data should also be the same. The dev `Dataset` should already be the same on each process, but if `train_loss` is used, that's calculated only for the fraction of data the given process handles. I'll try `torch.distributed.gather_object(train_loss)` to average across all processes/GPUs. This will also help for logging the training progress.
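For concreteness, a rough sketch of the averaging I have in mind (the helper name and exact call pattern are illustrative, not final code):

import torch.distributed as dist

def aggregate_train_loss(local_loss: float) -> float:
    # Gather each rank's loss on rank 0 and average it there.
    world_size = dist.get_world_size()
    gathered = [None] * world_size if dist.get_rank() == 0 else None
    dist.gather_object(local_loss, gathered, dst=0)
    if dist.get_rank() == 0:
        return sum(gathered) / world_size
    return local_loss  # non-main ranks keep their local value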
I'll follow up soon.
Hi @jeffpicard
thank you for creating this draft. Conceptually, I think this is a good way to finally integrate multi-GPU training in flair.
I tested this on 2 RTX A4000 GPUs, by increasing the mini_batch_chunk_size to be so large that all GPU memory is used, and the mini_batch_size to be either the same (multi-GPU) or 2x (single-GPU) to have a fair comparison in terms of batch updates.
Also, I used clearML for logging.
With that, I can confirm the ~2x speed improvement for 2 GPUs, and that the metrics at the end are about the same (although slightly worse for multi-GPU).
I observed that somehow the logging for multi-GPU is off by 1 epoch:
here you see that there was no report for epoch 1, but an epoch 21 magically appeared. I am not sure why that is.
Also, since currently the non-main processes also log values, I could observe the following:
Here, the non-main process is ahead of the main process, due to not having to evaluate. I am not sure if that is good, or if we should rather synchronize the processes at the end of each epoch.
Obviously splitting the evaluation would also be an option, but I think that would imply a lot of changes that make this PR more complicated.
I wonder how the plugins are impacted by multi-GPU. Logger plugins should obviously only work on the main process, while others, like the lr-scheduler plugins, need to be run on every process.
Note: currently the lr-scheduler doesn't know that multi-GPU training uses a higher batch size / fewer train steps:
Using the AnnealOnPlateau scheduler doesn't work, as the non-main processes fail without an eval metric.
Thanks for looking at this @helpmefindaname !
logging at multi-gpu is off by 1 epoch
Ahh, sorry about that. I think it's from the new call to .set_epoch(epoch) which was off by 1.
the non-mainprocess is ahead of the main process [...] we should rather syncronize the processes
DistributedDataParallel should be synchronizing every backward(), but there was a bug. I fixed it.
plugins
Thanks to @hallerpatrick's good idea to disable the logger on all but the main process, I've simplified the plugins to run on all processes. This makes the AnnealOnPlateau work. However, yes, if a plugin needs to synchronize information from all processes, it'll have to explicitly do that.
I ran into an unfortunate wrinkle -- I no longer see a speedup after the following bug fix: I noticed the gradients were not the same on each process/GPU for a given `epoch_num` and `batch_no`, like they should be. I think this is because pytorch's synchronization implementation relies on hooks that get called when you `__call__` a model rather than just use `forward_loss`. Changing the Trainer:
loss, datapoint_count = self.model.forward_loss(batch_step)
# becomes
loss, datapoint_count = self.model(batch_step)
fixes the gradients, but makes multiple GPUs a bit slower than a single GPU. Any idea what could be going on that's making it slower?
Any idea what could be going on that's making it slower?
Aha, with a bigger batch size, multiple GPUs are faster again. There's a little overhead to synchronizing the gradients, so the bigger the batch size, the more the overhead can be amortized.
I've fixed most of what's mentioned above:
- Process forking now happens inside `.train`, so all users have to do is add the `multi_gpu=True` argument
- The metrics logged during training are now averaged/summed from all GPUs rather than printing the rank=0 data
- Removed the `DistributedModel` wrapper
I'll push these changes up.
I'm still stuck on:
- What to do about `forward` vs `forward_loss`. In order to get the gradients to synchronize, pytorch relies on hooks run by `__call__`, which then invokes the special function `forward`. flair's trainer relies on `forward_loss`, which is potentially convenient because `forward` can just be redirected to `forward_loss`. But some `Model`s also use `forward`. One option is to refactor all models so that either all use `forward` or none do, but that's complex ¯\_(ツ)_/¯.
- I need to make `TransformerEmbeddings` work with pickle. Currently getting `TypeError: DistilBertModel.__init__() got an unexpected keyword argument 'instance_parameters'`.
Let me know if you have any thoughts on forward.
And here's an example of running it on the latest commit:
from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
if __name__ == "__main__":
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)
What to do about forward vs forward_loss
Oh, this can be resolved without a big refactor by patching `forward`, similar to what Fabric does here.
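Roughly what I have in mind, as a sketch (the helper name is mine and this is not exactly what Fabric does internally):

import types

def patch_forward_to_forward_loss(ddp_model):
    # Redirect the wrapped module's forward to forward_loss, so calling
    # ddp_model(batch) still fires DDP's gradient-sync hooks but runs
    # flair's loss computation.
    def forward(self, *args, **kwargs):
        return self.forward_loss(*args, **kwargs)

    ddp_model.module.forward = types.MethodType(forward, ddp_model.module)

# In the trainer (sketch):
# patch_forward_to_forward_loss(self.ddp_model)
# loss, datapoint_count = self.ddp_model(batch_step)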
make TransformerEmbeddings work with pickle
I see there are other objects that can't pickle, like Sentences with spans. Since a lot of things might not work, I'll plan to go back to the `launch_distributed` approach from the example in the original description. Doing it that way doesn't require everything to be pickleable. There are a couple of other options, like invoking the script with `torchrun script.py`, or integrating Lightning Fabric, which has a launcher that doesn't require objects to pickle. If you connect more with either of those, or with committing to making all objects pickleable, let me know.
I'll add a commit soon with these changes, which I hope can take this out of draft.
Done! @helpmefindaname and @HallerPatrick can you please take another look? This looks good to me. What do you think about merging?
- Plugins
  - I added a property so that each plugin can set whether it gets run on all processes or just the main one. It defaults to true (seemed safer). AnnealingPlugin and LinearSchedulerPlugin are the only plugins that run on all processes (a rough sketch of the property is below this list).
- "lr-scheduler doesn't know that multi-gpu training uses a higher batch-size"
  - I modified the calculation to take this into account.
- Example script
  - I modified the script in the top comment to reflect a minimal example of running this.
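For clarity, the shape of that per-plugin switch as a rough sketch; the property name, default, and base-class import are my assumptions here, not necessarily what the commit uses:

from flair.trainers.plugins import TrainerPlugin

class ExampleLoggingPlugin(TrainerPlugin):
    @property
    def attach_to_all_processes(self) -> bool:
        # Logging-style plugins only need the main process; scheduler
        # plugins keep the default (True) so every rank steps identically.
        return False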
Thanks for running CI. I think the same error occurs on master. I think @helpmefindaname already fixed it in this commit but it hasn't merged yet. I've added that commit's diff to this branch, with a minor bug fix changing del to pop.
I think CI should pass now.
Hey @jeffpicard, sorry for the late replies.
So I am taking a look now. I am testing with your example:
from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.distributed_utils import launch_distributed
def main():
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)


if __name__ == "__main__":
    launch_distributed(main)
I think in your earlier message you forgot to add launch_distributed. Is that right?
I also had problems running the example, because I tried running the script with torchrun/torch.distributed.launch, which is my fault. But we should definitely add documentation for multi-GPU!
Then I ran into another problem: the models that should be moved to different devices have different parameters:
Moving model on device: 0
TextClassifier(
  (embeddings): DocumentTFIDFEmbeddings()
  (decoder): Linear(in_features=8921, out_features=2, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): None
  (weight_tensor) None
)
Moving model on device: 1
TextClassifier(
  (embeddings): DocumentTFIDFEmbeddings()
  (decoder): Linear(in_features=9539, out_features=2, bias=True)
  (dropout): Dropout(p=0.0, inplace=False)
  (locked_dropout): LockedDropout(p=0.0)
  (word_dropout): WordDropout(p=0.0)
  (loss_function): CrossEntropyLoss()
  (weights): None
  (weight_tensor) None
)
The decoder has a different number of features. Is this due to starting multiple processes and the data preprocessing yielding different results in each?
Generally, data preprocessing should only be done on the main process. So this works:
from flair.datasets import IMDB
from flair.embeddings import DocumentTFIDFEmbeddings, TransformerDocumentEmbeddings
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from flair.distributed_utils import launch_distributed
def main():
    corpus = IMDB()
    corpus.downsample(0.01)
    label_type = "sentiment"
    label_dictionary = corpus.make_label_dictionary(label_type)
    launch_distributed(train, corpus, label_type, label_dictionary)


def train(corpus, label_type, label_dictionary):
    embeddings = DocumentTFIDFEmbeddings(train_dataset=corpus.train)
    # embeddings = TransformerDocumentEmbeddings(model="distilbert-base-uncased")  # serialization error
    model = TextClassifier(embeddings, label_type, label_dictionary=label_dictionary)
    trainer = ModelTrainer(model, corpus)
    trainer.fine_tune("./tmp", max_epochs=1, mini_batch_size=16, multi_gpu=True)


if __name__ == "__main__":
    main()
I think in Hugging Face all preprocessing operations are agnostic to this, but I am not sure.
I don't know if it makes sense to add some type of guard to all data processing operations (sounds like a lot of work) or just make it quite clear that just adding `multi_gpu=True` to the trainer will not work.
Any thoughts on that?
Hey @HallerPatrick. Thanks again, I really appreciate your continued thought on this and bearing with me.
Oops, sorry about that, I miscommunicated which script was the latest one to use. I just added a commit putting the example script into the repo under examples/multi_gpu, along with some documentation in the README explaining what to know when adding `multi_gpu=True`.
add some type of guards
Great idea. I added a guard in `.train_custom` in the latest commit so that training will raise an exception if the corpus is different on different processes (which can then lead to other preprocessing being different). In the docs, I included your method of initializing the corpus before spawning. However, that method won't work e.g. for NER, since Sentences with entity spans aren't serializable. So I also mentioned setting the random seed to solve this. The latest commit adds an `__eq__` method to `Sentence` to enable the guard.
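Roughly, the guard has this shape (sketch only; the helper name and the fingerprint choice are illustrative, not exactly what the commit does):

import torch.distributed as dist

def validate_corpus_same_each_process(corpus) -> None:
    # Each rank builds a cheap fingerprint of its corpus; any mismatch aborts.
    fingerprint = (len(corpus.train), str(corpus.train[0]))
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, fingerprint)
    if any(other != gathered[0] for other in gathered):
        raise RuntimeError(
            "Corpus differs between processes. Build the corpus before spawning "
            "or set a fixed random seed so every GPU sees identical data."
        )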
What do you think? (I'm not aware of outstanding issues)
I just rebased to pull in the type-checking updates to master and resolved conflicts. The guard wasn't working with non-serializable sentences so I added a commit fixing that (and removing the Sentence.__eq__ method since it's no longer needed).
How's this look?
Any thoughts on merging this @HallerPatrick @helpmefindaname?
@jeffpicard Thank you very much for your help here! Sorry it took so long :P
Are there any open issues related to this PR?
A big thank you for all your hard work in adding this @jeffpicard - this is a huge new feature that many people have been waiting for! Can't wait to see this in action :)
For sure! Thank you for building a framework that's so easy to love.