
Multiple GPU low performance

Open jetstudio-io opened this issue 1 year ago • 9 comments

Hello, I have an issue with multiple GPU performance.

  • Using the recipe lora_finetune_single_device with the config mini_lora_single_device.yaml on a single RTX 6000 Ada, I get ~5 it/s.
  • Using the recipe lora_finetune_distributed with the config mini_lora.yaml on 2 x RTX 6000 Ada, I get 1.5 s/it.

The dataset I use for fine-tuning is HuggingFaceFW/fineweb-edu-score-2. How can I improve performance on multiple GPUs?

jetstudio-io avatar Oct 01 '24 21:10 jetstudio-io

Hi @jetstudio-io, thanks for the question. The it/s or sec/it metric is not a great indicator of performance here. Instead, I would check the logs for tokens per second to do a better comparison. For example:

$ cat /tmp/full-llama3.2-finetune/log_1727815865.txt
Step 1 | loss:2.7667150497436523 lr:2e-05 tokens_per_second_per_gpu:766.5005627443386

Or you can see it over time if you log with WandB and set log_peak_memory_stats=True in your launch command.
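If you are using the DiskLogger and want an average over a whole run rather than eyeballing single lines, here is a rough parsing sketch (it assumes the key:value log format shown above; swap in your own log file path):

import re

values = []
with open("/tmp/full-llama3.2-finetune/log_1727815865.txt") as f:
    for line in f:
        # pull the tokens_per_second_per_gpu value out of each "key:value" step line
        match = re.search(r"tokens_per_second_per_gpu:([0-9.]+)", line)
        if match:
            values.append(float(match.group(1)))

if values:
    print(f"mean tokens/s/GPU over {len(values)} steps: {sum(values) / len(values):.1f}")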

Many factors can impact raw seconds per iteration (gradient accumulation especially), and it is not necessarily indicative of how quickly training converges. That being said, there are still other ways to improve performance. You can check our documentation page on the memory/performance features you can enable to get some ideas (cc @felipemello1): https://pytorch.org/torchtune/main/tutorials/memory_optimizations.html.

A very direct way to improve throughput is to enable packing in your dataset. If you are using the torchtune dataset builder functions, you can simply pass packed=True in your config or launch command.

RdoubleA avatar Oct 01 '24 21:10 RdoubleA

Like @RdoubleA said, the config has "gradient_accumulation_steps: 16", which means that one logged step is actually 16 forward/backward passes.
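As a rough back-of-the-envelope sketch of why it/s alone is misleading (the formula and numbers below are illustrative assumptions, not torchtune output or your actual runs):

def approx_tokens_per_sec(batch_size, grad_accum_steps, avg_tokens_per_sample, secs_per_logged_step):
    # one logged step covers grad_accum_steps micro-batches of batch_size samples each
    tokens_per_logged_step = batch_size * grad_accum_steps * avg_tokens_per_sample
    return tokens_per_logged_step / secs_per_logged_step

# Illustrative numbers only: batch_size=2, gradient_accumulation_steps=16,
# ~2048 tokens per packed sample, 1.5 s per logged step -> ~43.7k tokens/s.
print(approx_tokens_per_sec(2, 16, 2048, 1.5))

So two runs with different batch sizes or accumulation settings can show very different it/s while doing comparable work; tokens/s normalizes for that.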

Maybe try the following:

tune run lora_finetune_single_device --config phi3/mini_lora_single_device \
compile=True \
dataset.packed=True \
tokenizer.max_seq_len=2048 \
batch_size=4 \
gradient_accumulation_steps=2 \
enable_activation_checkpointing=False \
log_every_n_steps=1 \
metric_logger._component_=torchtune.training.metric_logging.WandBLogger \
log_peak_memory_stats=True

If you are running out of memory, set enable_activation_checkpointing=True. Otherwise, increase batch_size.

You can see your memory usage on the Weights & Biases website.

Also, use the torchtune/PyTorch nightlies for maximum performance: https://github.com/pytorch/torchtune#install-nightly-release

felipemello1 avatar Oct 01 '24 21:10 felipemello1

Thanks for your advice, I'll measure the tokens/s.

jetstudio-io avatar Oct 02 '24 07:10 jetstudio-io

Hi @RdoubleA, @felipemello1, I measured tokens/s/GPU:

  • 1 GPU (RTX 6000 Ada) with tune run lora_finetune_single_device --config phi3/mini_lora_single_device: ~4100 tokens/s
  • 4 GPUs (4 x RTX 6000 Ada) with tune run lora_finetune_distributed --config phi3/mini_lora_single_device: ~1100 tokens/s/GPU

Is there something wrong? How can I debug or improve performance for multi-GPU training?

One other thing: I used text_completion_dataset, and when I set batch_size=2 and packed=True I got this error:

RuntimeError: stack expects each tensor to be equal size, but got [5] at entry 0 and [9] at entry 1

Do you have any idea how to fix it? I use the dataset HuggingFaceFW/fineweb-edu-score-2 (I downloaded a parquet file to run locally). The dataset initialization:

        world_size, rank = training.get_world_size_and_rank()

        dataset = text_completion_dataset(
            tokenizer=self._tokenizer,
            source='parquet',
            column='content',
            data_dir=self._config.dataset.path,
            packed=True,
            split='train',
        )

        self._sampler = DistributedSampler(
            dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0
        )

        self._dataloader = DataLoader(
            dataset=dataset,
            batch_size=self._config.dataset.batch_size,
            sampler=self._sampler,
            # dropping last avoids shape issues with compile + flex attention
            drop_last=True
        )

Thanks,

jetstudio-io avatar Oct 12 '24 14:10 jetstudio-io

@jetstudio-io, regarding the difference in speed, it is probably due to some difference in the configs. You could post them here for me to take a look, but the most obvious knobs are:

  • compile (if True, it's faster)
  • enable_activation_checkpointing (if True, it's slower, but uses less memory)
  • enable_activation_offloading (if True, it's slower, but uses less memory)
  • fsdp_cpu_offload (if True, it's slower, but uses less memory)

Regarding the dataset, I am not sure. Where exactly is this error raised? Do you have tokenizer.max_seq_len defined? @RdoubleA, have you seen this before?
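One thing worth double-checking, since your custom recipe builds its own DataLoader: each pack from a packed dataset carries variable-length metadata (e.g. the per-pack sequence lengths), which the default collate cannot stack, and which would match the unequal-size tensors in your error. The built-in recipes switch to a packed-aware collate in that case. A minimal sketch, assuming torchtune.data.padded_collate_packed is available in your installed version and reusing the dataset and sampler from your snippet:

from torch.utils.data import DataLoader

# assumed available in your torchtune version; the built-in recipes use it
# whenever dataset.packed=True
from torchtune.data import padded_collate_packed

dataloader = DataLoader(
    dataset=dataset,      # the packed text_completion_dataset from your snippet
    batch_size=2,
    sampler=sampler,      # your DistributedSampler
    drop_last=True,
    # default_collate fails on packed batches because the per-pack metadata
    # lists have different lengths, hence the "stack expects each tensor to
    # be equal size" error
    collate_fn=padded_collate_packed,
)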

felipemello1 avatar Oct 14 '24 14:10 felipemello1

Hello, regarding the configuration, this is the config file that I used:

# Fine-tuning arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 2
gradient_accumulation_steps: 8
compile: False
enable_activation_checkpointing: False
enable_activation_offloading: False
fsdp_cpu_offload: False

jetstudio-io avatar Oct 15 '24 06:10 jetstudio-io

For the dataset configuration:

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: parquet
  split: train
  data_dir: ds/sample
  batch_size: 2
  packed: True

The error traceback:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /root/torchtune/bin/tune:8 in <module>                                                           │
│                                                                                                  │
│   5 from torchtune._cli.tune import main                                                         │
│   6 if __name__ == '__main__':                                                                   │
│   7 │   sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])                         │
│ ❱ 8 │   sys.exit(main())                                                                         │
│   9                                                                                              │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torchtune/_cli/tune.py:49 in main             │
│                                                                                                  │
│   46 def main():                                                                                 │
│   47 │   parser = TuneCLIParser()                                                                │
│   48 │   args = parser.parse_args()                                                              │
│ ❱ 49 │   parser.run(args)                                                                        │
│   50                                                                                             │
│   51                                                                                             │
│   52 if __name__ == "__main__":                                                                  │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torchtune/_cli/tune.py:43 in run              │
│                                                                                                  │
│   40 │                                                                                           │
│   41 │   def run(self, args: argparse.Namespace) -> None:                                        │
│   42 │   │   """Execute CLI"""                                                                   │
│ ❱ 43 │   │   args.func(args)                                                                     │
│   44                                                                                             │
│   45                                                                                             │
│   46 def main():                                                                                 │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torchtune/_cli/run.py:196 in _run_cmd.        │
│                                                                                                  │
│   193 │   │   │   │   )                                                                          │
│   194 │   │   │   self._run_distributed(args, is_builtin=is_builtin)                             │
│   195 │   │   else:                                                                              │
│ ❱ 196 │   │   │   self._run_single_device(args, is_builtin=is_builtin)                           │
│   197                                                                                            │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torchtune/_cli/run.py:105 in                  │
│ _run_single_device.                                                                              │
│                                                                                                  │
│   102 │   │   │   runpy.run_path(str(args.recipe), run_name="__main__")                          │
│   103 │   │   else:                                                                              │
│   104 │   │   │   # custom recipes are specified as a relative module dot path                   │
│ ❱ 105 │   │   │   runpy.run_module(str(args.recipe), run_name="__main__")                        │
│   106 │                                                                                          │
│   107 │   def _is_distributed_args(self, args: argparse.Namespace):                              │
│   108 │   │   """Check if the user is trying to run a distributed recipe."""                     │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/runpy.py:227 in run_module                                  │
│                                                                                                  │
│   224 │   │   return _run_module_code(code, init_globals, run_name, mod_spec)                    │
│   225 │   else:                                                                                  │
│   226 │   │   # Leave the sys module alone                                                       │
│ ❱ 227 │   │   return _run_code(code, {}, init_globals, run_name, mod_spec)                       │
│   228                                                                                            │
│   229 def _get_main_module_details(error=ImportError):                                           │
│   230 │   # Helper that gives a nicer error message when attempting to                           │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/runpy.py:86 in _run_code                                    │
│                                                                                                  │
│    83 │   │   │   │   │      __loader__ = loader,                                                │
│    84 │   │   │   │   │      __package__ = pkg_name,                                             │
│    85 │   │   │   │   │      __spec__ = mod_spec)                                                │
│ ❱  86 │   exec(code, run_globals)                                                                │
│    87 │   return run_globals                                                                     │
│    88                                                                                            │
│    89 def _run_module_code(code, init_globals=None,                                              │
│                                                                                                  │
│ /root/torchtune/tune_lora_single_device.py:535 in <module>                                       │
│                                                                                                  │
│   532                                                                                            │
│   533                                                                                            │
│   534 if __name__ == "__main__":                                                                 │
│ ❱ 535 │   sys.exit(recipe_main())                                                                │
│   536                                                                                            │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torchtune/config/_parse.py:99 in wrapper      │
│                                                                                                  │
│    96 │   │   yaml_args, cli_args = parser.parse_known_args()                                    │
│    97 │   │   conf = _merge_yaml_and_cli_args(yaml_args, cli_args)                               │
│    98 │   │                                                                                      │
│ ❱  99 │   │   sys.exit(recipe_main(conf))                                                        │
│   100 │                                                                                          │
│   101 │   return wrapper                                                                         │
│   102                                                                                            │
│                                                                                                  │
│ /root/torchtune/tune_lora_single_device.py:531 in recipe_main                                    │
│                                                                                                  │
│   528 │   recipe = LoRAFinetuneRecipeSingleDevice(cfg=cfg)                                       │
│   529 │   recipe.setup(cfg=cfg)                                                                  │
│   530 │   # recipe.save_checkpoint(epoch=0)                                                      │
│ ❱ 531 │   recipe.train()                                                                         │
│   532                                                                                            │
│   533                                                                                            │
│   534 if __name__ == "__main__":                                                                 │
│                                                                                                  │
│ /root/torchtune/tune_lora_single_device.py:486 in train                                          │
│                                                                                                  │
│   483 │   │   │   # self._dataloader.load_data(curr_epoch)                                       │
│   484 │   │   │   t0 = time.perf_counter()                                                       │
│   485 │   │   │   self._token_per_sec = 0                                                        │
│ ❱ 486 │   │   │   for idx, batch in enumerate(pbar := tqdm(self._dataloader)):                   │
│   487 │   │   │   │   loss = self._tune_lost_step(batch)                                         │
│   488 │   │   │   │   if loss is None:                                                           │
│   489 │   │   │   │   │   continue                                                               │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/tqdm/std.py:1181 in __iter__                  │
│                                                                                                  │
│   1178 │   │   time = self._time                                                                 │
│   1179 │   │                                                                                     │
│   1180 │   │   try:                                                                              │
│ ❱ 1181 │   │   │   for obj in iterable:                                                          │
│   1182 │   │   │   │   yield obj                                                                 │
│   1183 │   │   │   │   # Update and possibly print the progressbar.                              │
│   1184 │   │   │   │   # Note: does not call self.update(1) for speed optimisation.              │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630 in __next__│
│                                                                                                  │
│    627 │   │   │   if self._sampler_iter is None:                                                │
│    628 │   │   │   │   # TODO(https://github.com/pytorch/pytorch/issues/76750)                   │
│    629 │   │   │   │   self._reset()  # type: ignore[call-arg]                                   │
│ ❱  630 │   │   │   data = self._next_data()                                                      │
│    631 │   │   │   self._num_yielded += 1                                                        │
│    632 │   │   │   if self._dataset_kind == _DatasetKind.Iterable and \                          │
│    633 │   │   │   │   │   self._IterableDataset_len_called is not None and \                    │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py:673            │
│ in _next_data                                                                                    │
│                                                                                                  │
│    670 │                                                                                         │
│    671 │   def _next_data(self):                                                                 │
│    672 │   │   index = self._next_index()  # may raise StopIteration                             │
│ ❱  673 │   │   data = self._dataset_fetcher.fetch(index)  # may raise StopIteration              │
│    674 │   │   if self._pin_memory:                                                              │
│    675 │   │   │   data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)            │
│    676 │   │   return data                                                                       │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:55           │
│ in fetch                                                                                         │
│                                                                                                  │
│   52 │   │   │   │   data = [self.dataset[idx] for idx in possibly_batched_index]                │
│   53 │   │   else:                                                                               │
│   54 │   │   │   data = self.dataset[possibly_batched_index]                                     │
│ ❱ 55 │   │   return self.collate_fn(data)                                                        │
│   56                                                                                             │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py:317.       │
│ in default_collate                                                                               │
│                                                                                                  │
│   314 │   │   >>> default_collate_fn_map.update(CustomType, collate_customtype_fn)               │
│   315 │   │   >>> default_collate(batch)  # Handle `CustomType` automatically                    │
│   316 │   """                                                                                    │
│ ❱ 317 │   return collate(batch, collate_fn_map=default_collate_fn_map)                           │
│   318                                                                                            │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py:155        │
│ in collate.                                                                                      │
│                                                                                                  │
│   152 │   │   │   │   # use `type(data)(...)` to create the new mapping.                         │
│   153 │   │   │   │   # Create a clone and update it if the mapping type is mutable.             │
│   154 │   │   │   │   clone = copy.copy(elem)                                                    │
│ ❱ 155 │   │   │   │   clone.update({key: collate([d[key] for d in batch], collate_fn_map=colla   │
│   156 │   │   │   │   return clone                                                               │
│   157 │   │   │   else:                                                                          │
│   158 │   │   │   │   return elem_type({key: collate([d[key] for d in batch], collate_fn_map=c   │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py:155        │
│ in <dictcomp>.                                                                                   │
│                                                                                                  │
│   152 │   │   │   │   # use `type(data)(...)` to create the new mapping.                         │
│   153 │   │   │   │   # Create a clone and update it if the mapping type is mutable.             │
│   154 │   │   │   │   clone = copy.copy(elem)                                                    │
│ ❱ 155 │   │   │   │   clone.update({key: collate([d[key] for d in batch], collate_fn_map=colla   │
│   156 │   │   │   │   return clone                                                               │
│   157 │   │   │   else:                                                                          │
│   158 │   │   │   │   return elem_type({key: collate([d[key] for d in batch], collate_fn_map=c   │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py:142        │
│ in collate                                                                                       │
│                                                                                                  │
│   139 │                                                                                          │
│   140 │   if collate_fn_map is not None:                                                         │
│   141 │   │   if elem_type in collate_fn_map:                                                    │
│ ❱ 142 │   │   │   return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)         │
│   143 │   │                                                                                      │
│   144 │   │   for collate_type in collate_fn_map:                                                │
│   145 │   │   │   if isinstance(elem, collate_type):                                             │
│                                                                                                  │
│ /root/torchtune/.venv/lib/python3.10/site-packages/torch/utils/data/_utils/collate.py:214        │
│ in collate_tensor_fn                                                                             │
│                                                                                                  │
│   211 │   │   numel = sum(x.numel() for x in batch)                                              │
│   212 │   │   storage = elem._typed_storage()._new_shared(numel, device=elem.device)             │
│   213 │   │   out = elem.new(storage).resize_(len(batch), *list(elem.size()))                    │
│ ❱ 214 │   return torch.stack(batch, 0, out=out)                                                  │
│   215                                                                                            │
│   216                                                                                            │
│   217 def collate_numpy_array_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Typ   │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: stack expects each tensor to be equal size, but got [4] at entry 0 and [6] at entry 1

jetstudio-io avatar Oct 15 '24 06:10 jetstudio-io

> Hello, regarding the configuration, this is the config file that I used:
>
> # Fine-tuning arguments
> epochs: 1
> max_steps_per_epoch: null
> batch_size: 2
> gradient_accumulation_steps: 8
> compile: False
> enable_activation_checkpointing: False
> enable_activation_offloading: False
> fsdp_cpu_offload: False

Thanks! Can you post it for both single_device and distributed so I can compare? Even a different batch size, for example, would impact your tokens per second.


Regarding the dataset, I will leave this one to @RdoubleA.

felipemello1 avatar Oct 15 '24 14:10 felipemello1

This is my full config file; I used it for both single_device and distributed:

model:
  _component_: torchtune.models.phi3.lora_phi3_mini
  lora_attn_modules: ['q_proj', 'v_proj', 'k_proj', 'o_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 32
  lora_alpha: 64
  lora_dropout: 0.0

# Tokenizer
tokenizer:
  _component_: torchtune.models.phi3.phi3_mini_tokenizer
  path: /tmp/Phi-3-mini-4k-instruct/tokenizer.model
  max_seq_len: 4098

# Checkpointer
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Phi-3-mini-4k-instruct
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Phi-3-mini-4k-instruct
  model_type: PHI3_MINI
resume_from_checkpoint: False
save_adapter_weights_only: False

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: parquet
  split: train
  data_dir: ds/sample
  batch_size: 1
  packed: False
seed: null
shuffle: True

# Fine-tuning arguments
epochs: 1
max_steps_per_epoch: null
batch_size: 1
gradient_accumulation_steps: 16
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: False
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
output_dir: /tmp/phi3_lora_finetune_output
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /tmp/Phi-3-mini-4k-instruct/logs
log_every_n_steps: 1
log_peak_memory_stats: False

I used batch_size=1 and packed=False because I got an error with batch_size=2 and packed=True, as in the traceback I posted before.

jetstudio-io avatar Oct 15 '24 21:10 jetstudio-io