                        Significant performance degradation with multi-GPU training on newer torch/transformers
System Info
# Env 1
- `Accelerate` version: 0.30.1
- Platform: Linux-5.15.0-1058-aws-x86_64-with-glibc2.31
- `accelerate` bash location: /home/ubuntu/miniconda3/envs/train/bin/accelerate
- Python version: 3.10.14
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.3.0+cu121 (True)
- Transformers version: 4.40.2
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 186.70 GB
- GPU type: NVIDIA A10G
- `Accelerate` default config:
        Not found
# Env 2
- `Accelerate` version: 0.20.3
- Platform: Linux-5.15.0-1058-aws-x86_64-with-glibc2.31
- Python version: 3.10.14
- Numpy version: 1.26.4
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- Transformers version: 4.30.2
- PyTorch XPU available: False
- System RAM: 186.70 GB
- GPU type: NVIDIA A10G
- `Accelerate` default config:
        Not found
Information
- [ ] The official example scripts
- [X] My own modified scripts
Tasks
- [ ] One of the scripts in the examples/ folder of Accelerate or an officially supported `no_trainer` script in the `examples` folder of the `transformers` repo (such as `run_no_trainer_glue.py`)
- [X] My own task or dataset (give details below)
Reproduction
I am using a g5.12xlarge EC2 instance for this test, but I have observed this issue on other machines as well. This is just a minimal example to demonstrate the issue; in my actual usage, the degradation is even worse.
- Create `env1` and install: `pip install transformers torch accelerate`.
- Create `env2` and install: `pip install transformers==4.30.2 torch==2.0.1 accelerate==0.20.3`.
- Run the following script using `torchrun --nproc-per-node=4 test.py`.
from typing import Iterator
import torch
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import IterableDataset
class DummyDataset(IterableDataset):
    def __iter__(self) -> Iterator:
        while True:
            yield {
                "input_ids": torch.randint(4000, size=(512,)),
                "labels": torch.randint(4000, size=(64,)),
            }
if __name__ == "__main__":
    model = T5ForConditionalGeneration.from_pretrained("google/t5-efficient-small")
    dataset = DummyDataset()
    training_args = TrainingArguments(
        output_dir="./output/",
        max_steps=1000_000,
        per_device_train_batch_size=16,
    )
    trainer = Trainer(model=model, train_dataset=dataset, args=training_args)
    trainer.train()
Observations
- On `env1`, GPU 0 utilization keeps fluctuating and the estimated training time is shown as ~82hrs.
- On `env2`, all GPUs have utilization maxed out and the estimated training time is shown as ~66hrs.
Expected behavior
Both environments should have similar training time.
Accelerate isn't the issue.
Timings based on my 2x4090:
Assume 0.x are accelerate versions
On transformers v4.30.2:
- 0.30.1: ~28.5hrs
- 0.29.3: ~29hrs <- We fixed this
- 0.28.0: ~28.5hrs
- 0.21.0 (minimum for transformers): ~28.5hrs
On transformers 4.40.2:
- 0.30.1: ~29.5hrs
- 0.29.3: ~29.5hrs
- 0.28.0: ~29.5hrs
On transformers 4.30.2:
- 0.20.3: ~28.5hrs
- 0.30.1: ~28.5hrs
So you can see that this issue might involve the Trainer; however, I didn't actually see any changes here, as you can tell.
In a last-ditch effort:
- torch==2.0.1, accelerate==0.30.1, transformers==4.30.2: 28.5hrs
- torch==2.0.1, accelerate==0.20.3, transformers==4.30.2: 28.5hrs
- torch==2.0.1, accelerate==0.30.1, transformers==4.40.2: 29.5hrs
Now we are seeing issues from transformers instead.
Narrowing it down further (assuming same torch and accelerate):
- transformers==4.39.3: 29.5hrs
- ...
- transformers==4.34.1: 29.5hrs
- ...
- transformers==4.32.1: 29.5hrs
- transformers==4.31.0: 28.5hrs
So the issue appears starting with transformers 4.32.1 on torch 2.0.1 (4.31.0 is still fast).
I'm not sure it's worth us fixing, since updating your torch version will solve this problem.
Is there a specific use-case for needing torch 2.0.1 and you can't use a later version?
Also: one thing I found that could affect the estimate by about an hour was the temperature of my GPU. If it was cool / a cold start, it could be an hour slower. There are lots of variables at play here, and I'm not sure what exactly is causing your issue, even after looking thoroughly.
@muellerzr Thanks a lot for checking. All your tests seem to be in the same ballpark, so I don't think this really reproduces the issue. Also note that the performance seems to degrade with a larger number of GPUs, so 2x 4090 may not be enough to reproduce it. I can run some more tests on my end if you have suggestions.
Regarding torch version: Unfortunately, the problem (as I have described above) is that recent torch/transformers versions are actually the ones that are slow. Therefore, I cannot just upgrade them to fix the problem. In fact, I actually upgraded the libraries when I noticed the problem.
Regarding temperature: I don't think that's the issue. I have tested this on multiple machines and multiple times. Switching the env changes the runtime significantly, so I doubt that temperature is to blame here. Also, the issue here is not ~1hr worse performance but a factor of 2 in many cases (~6hrs vs ~12hrs, or 66hrs vs 82hrs as in my example above).
I’ll see if I can get access to an 8-node system to debug.
BUT that would mean we’re hitting a ton of unnecessary distributed communications somewhere along the line (since it was working before).
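If you want to sanity-check that theory on your end, here is a rough, untested sketch (the callback name and step numbers are just placeholders I made up) that profiles a handful of steps via a TrainerCallback so you can see whether nccl kernels dominate the CUDA time:
from torch.profiler import ProfilerActivity, profile
from transformers import TrainerCallback

class NcclProfilerCallback(TrainerCallback):
    # Profile a few training steps; nccl all_reduce/broadcast kernels showing up
    # near the top of the table would point at excessive distributed communication.
    def __init__(self, start_step=20, num_steps=5):
        self.start_step = start_step
        self.stop_step = start_step + num_steps
        self.prof = None

    def on_step_begin(self, args, state, control, **kwargs):
        if state.global_step == self.start_step:
            self.prof = profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA])
            self.prof.start()

    def on_step_end(self, args, state, control, **kwargs):
        if self.prof is not None and state.global_step >= self.stop_step:
            self.prof.stop()
            if state.is_world_process_zero:
                print(self.prof.key_averages().table(sort_by="cuda_time_total", row_limit=25))
            self.prof = None
It could be passed to the repro script above as Trainer(..., callbacks=[NcclProfilerCallback()]).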
I ran some tests again (all done in fresh envs):
- torch==2.0.1 transformers==4.30.2 accelerate==0.20.3: shows ~66hrs.
- torch==2.0.1 transformers==4.40.2 accelerate==0.30.1: shows ~67hrs.
- torch==2.1.2 transformers==4.40.2 accelerate==0.30.1: shows ~67hrs.
- torch==2.2.2 transformers==4.40.2 accelerate==0.30.1: shows ~82hrs.
- torch==2.3.0 transformers==4.40.2 accelerate==0.30.1: shows ~82hrs.
It looks like recent transformers/accelerate versions are only slightly worse when used with torch==2.0.1/torch==2.1.2 but significantly worse on torch==2.2.2 and later. Any idea on what could be going on? Is this more of a torch issue?
Let me dig today and see if we have any torch 2.2+ import checks that could differ.
Do we have any update on this @muellerzr?
I thought maybe this has something to do with iterable vs. map-style datasets, so I ran the following test, but it's the same story.
from typing import Iterator
import torch
from transformers import T5ForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import IterableDataset, Dataset
class IterableDummyDataset(IterableDataset):
    def __iter__(self) -> Iterator:
        while True:
            yield {
                "input_ids": torch.randint(4000, size=(512,)),
                "labels": torch.randint(4000, size=(64,)),
            }
class MapDummyDataset(Dataset):
    def __len__(self):
        return 1000
    def __getitem__(self, i):
        return {
            "input_ids": torch.randint(4000, size=(512,)),
            "labels": torch.randint(4000, size=(64,)),
        }
if __name__ == "__main__":
    model = T5ForConditionalGeneration.from_pretrained("google/t5-efficient-small")
    dataset = MapDummyDataset()
    training_args = TrainingArguments(
        output_dir="./output/",
        max_steps=1000_000,
        per_device_train_batch_size=16,
    )
    trainer = Trainer(model=model, train_dataset=dataset, args=training_args)
    trainer.train()
This really looks like a torch issue. I have opened an issue here: https://github.com/pytorch/pytorch/issues/127077
Thanks for finding and reproducing in torch! Will keep a close eye on it 🤗
@abdulfatir BTW, they mentioned aligning the CPU cores. Adding the following to your accelerate config will do that:
enable_cpu_affinity: true
Alternatively, pass it as --enable_cpu_affinity on the command line.
Let me know if that does anything please 🙏
Thanks @muellerzr!
I normally use torchrun for my experiments. Here's what I did (please correct me if I am wrong).
- Create an accelerate config with the following values.
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: true
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
- Run the script at the top of this issue with `accelerate launch --config_file /home/ubuntu/.cache/huggingface/accelerate/default_config.yaml test-trans.py`.
- The performance is still much worse (more than 2x) than older versions.
Oh boy, okay. Well, that was a thought 😢
CC @stas00 if I misunderstood anything? (See the nccl issue)
Thanks a lot for your active help @muellerzr!
The g5 instance where I am facing this issue is a popular EC2 instance, so it would be great if we could find some way to resolve this. What I find confusing is that I am facing a similar issue with my codebase on a p4d instance (8x A100), which is very commonly used for LLM training, but I can't reproduce it with a minimal example like the one at the top. If I find a minimal example, I will post it here. The issue is the same: 1.5x-2x faster training on older versions of the libraries.
I'm not sure how --enable_cpu_affinity can help with this issue, and when it does help, the improvement is very small.
But I see that there is an in-depth discussion of this particular hardware happening at https://github.com/NVIDIA/nccl/issues/1298 by @abdulfatir.
I'd probably recommend taking that nccl discussion to AWS support next - @muellerzr, if @abdulfatir doesn't have a proper support plan, perhaps you could try to get some eyes from AWS via HF's AWS support channel?
While I was still at HF and using AWS, we had all kinds of incredible network performance degradations when Hyper-Threading was enabled, and also when the wrong, non-EFA-enabled libnccl library was getting loaded (though nccl has changed since then). EFA (and other non-standard network stacks) often has to be configured meticulously to work correctly - one misconfigured env var and off goes the performance.
Enabling HT in SLURM gave a 4x drop in performance w/ EFA v1.
@muellerzr I dug deeper into why I am seeing a slowdown on a p4d instance as well. I have narrowed things down to the fact that I see a slowdown on 4.31.0 but not on the version below it (4.30.2). One major change I notice in the diff is that 4.31.0 moves the DDP handling of iterable datasets out of transformers' own IterableDatasetShard and into self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) (rough sketch below). Any thoughts on why this could lead to a slowdown? I haven't really dug into accelerator.prepare yet.
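To make the difference concrete, here is a sketch of the two code paths as I understand them (simplified from memory, not the exact Trainer code; the variable names are mine):
import torch
from torch.utils.data import DataLoader, IterableDataset
from accelerate import Accelerator
from transformers.trainer_pt_utils import IterableDatasetShard

class DummyDataset(IterableDataset):
    def __iter__(self):
        while True:
            yield {"input_ids": torch.randint(4000, size=(512,)),
                   "labels": torch.randint(4000, size=(64,))}

dataset, batch_size = DummyDataset(), 16
accelerator = Accelerator()

# transformers <= 4.30.x: the Trainer sharded the iterable dataset itself, so every
# rank pulled its own batches locally via IterableDatasetShard.
sharded = IterableDatasetShard(
    dataset,
    batch_size=batch_size,
    num_processes=accelerator.num_processes,
    process_index=accelerator.process_index,
)
old_style_loader = DataLoader(sharded, batch_size=batch_size)

# transformers >= 4.31.0: the plain DataLoader goes through accelerator.prepare(),
# and accelerate's default for an IterableDataset is dispatch_batches=True, i.e. the
# main process iterates the dataset and broadcasts slices of each batch to the others.
new_style_loader = accelerator.prepare(DataLoader(dataset, batch_size=batch_size))
If that reading is right, the new path adds a broadcast per batch, which would line up with the extra communication suspected above.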
Okay, I dug deeper and found a piece of code confusing (I may just be sleepy, so please excuse me):
https://github.com/huggingface/accelerate/blob/v0.20.3/src/accelerate/data_loader.py#L697-L725
In L697, dispatch_batches is set to True for iterable datasets.
Then the if condition in L721 is only True when dispatch_batches is False, but inside it, in L722, there is a check if isinstance(new_dataset, IterableDataset). That check would never be True in the default case (dispatch_batches=None), right?
I can confirm that there's a significant speedup for my training code (from ~15hr to ~10hr) when I monkey-patch the accelerator to have dispatch_batches=False. That said, the speed I get on 4.30.2 is ~8.5hr, so performance is still somewhat worse.
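For anyone who wants to try the same workaround without monkey-patching accelerate internals: on recent transformers (>= 4.38, if I'm reading the docs right) the setting should also be reachable through TrainingArguments via accelerator_config. A minimal, unverified sketch:
from transformers import TrainingArguments

# Unverified sketch: accelerator_config is only available on transformers >= 4.38;
# this should be equivalent to forcing dispatch_batches=False on the Accelerator.
training_args = TrainingArguments(
    output_dir="./output/",
    max_steps=1_000_000,
    per_device_train_batch_size=16,
    accelerator_config={"dispatch_batches": False},
)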
@muellerzr any thoughts on my comments above?
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.