
Error while full finetuning Llama 4 Scout

Open agunapal opened this issue 5 months ago • 11 comments

I am running this with nightlies

Logs

tune run --nproc_per_node 8 full_finetune_distributed --config recipes/configs/llama4/scout_17B_16E_full.yaml batch_size=4 epochs=10
Running with torchrun...
W0716 20:03:05.675000 63909 site-packages/torch/distributed/run.py:774] 
W0716 20:03:05.675000 63909 site-packages/torch/distributed/run.py:774] *****************************************
W0716 20:03:05.675000 63909 site-packages/torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0716 20:03:05.675000 63909 site-packages/torch/distributed/run.py:774] *****************************************
INFO:torchtune.utils._logging:Running FullFinetuneRecipeDistributed with resolved config:

batch_size: 4
batch_size_val: 4
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /mnt/disks/data/hf_weights/Llama-4-Scout-17B-16E-Instruct
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00050'
  model_type: LLAMA4
  output_dir: /mnt/disks/data/logs/torchtune/llama4_17Bx16E/full
  recipe_checkpoint: null
clip_grad_norm: null
compile: false
data_parallel_replicate_dim: 1
data_parallel_shard_dim: -1
dataset:
  _component_: torchtune.datasets.chat_dataset
  conversation_column: conversations
  conversation_style: openai
  data_files: /mnt/disks/data/torchtune/data/train_top_1000_ankith.json
  packed: false
  source: json
  split: train[:95%]
  train_on_input: false
dataset_val:
  _component_: torchtune.datasets.chat_dataset
  conversation_column: conversations
  conversation_style: openai
  data_files: /mnt/disks/data/torchtune/data/train_top_1000_ankith.json
  packed: false
  source: json
  split: train[:95%]
  train_on_input: false
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
epochs: 10
fsdp_cpu_offload: true
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: true
loss:
  _component_: torch.nn.CrossEntropyLoss
  ignore_index: -100
  reduction: mean
max_steps_per_epoch: null
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: /mnt/disks/data/logs/torchtune/llama4_17Bx16E/full/logs
model:
  _component_: torchtune.models.llama4.llama4_scout_17b_16e
optimizer:
  _component_: torch.optim.AdamW
  fused: false
  lr: 2.0e-05
optimizer_in_bwd: false
output_dir: /mnt/disks/data/logs/torchtune/llama4_17Bx16E/full
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: false
resume_from_checkpoint: false
run_val_every_n_steps: 60
seed: null
shuffle: true
tensor_parallel_dim: 2
tensor_parallel_plan:
  _component_: torchtune.models.llama4.decoder_only_tp_plan
tokenizer:
  _component_: torchtune.models.llama4.llama4_transform
  max_num_tiles: 16
  max_seq_len: null
  path: /mnt/disks/data/hf_weights/Llama-4-Scout-17B-16E-Instruct/tokenizer.model

[Gloo] Rank 0 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 1 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 3 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 4 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 5 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 6 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 2 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 2 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 0 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 1 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
[Gloo] Rank 3 is connected to 3 peer ranks. Expected number of connected peer ranks is : 3
INFO:torchtune.utils._logging:Hint: enable_activation_checkpointing is True, but enable_activation_offloading isn't. Enabling activation offloading should reduce memory further.
INFO:torchtune.utils._logging:Set intra op parallelism no. of threads to 26
INFO:torchtune.utils._logging:Set intra op parallelism no. of threads to 26
INFO:torchtune.utils._logging:Set intra op parallelism no. of threads to 26
INFO:torchtune.utils._logging:Set intra op parallelism no. of threads to 26
INFO:torchtune.utils._logging:Set intra op parallelism no. of threads to 26
INFO:torchtune.utils._logging:Set intra op parallelism no. of threads to 26
INFO:torchtune.utils._logging:Set intra op parallelism no. of threads to 26
INFO:torchtune.utils._logging:Set intra op parallelism no. of threads to 26
Writing logs to /mnt/disks/data/logs/torchtune/llama4_17Bx16E/full/logs/log_1752696194.txt
INFO:torchtune.utils._logging:Distributed training is enabled. Instantiating model and loading checkpoint on Rank 0 ...
INFO:torchtune.utils._logging:Instantiating model and loading checkpoint took 410.16 secs
INFO:torchtune.utils._logging:Memory stats after model init:
	GPU peak memory active: 12.54 GiB
	GPU peak memory alloc: 12.54 GiB
	GPU peak memory reserved: 12.55 GiB
INFO:torchtune.utils._logging:Optimizer is initialized.
INFO:torchtune.utils._logging:Loss is initialized.
INFO:torchtune.utils._logging:No learning rate scheduler configured. Using constant learning rate.
WARNING:torchtune.utils._logging: Profiling disabled.
INFO:torchtune.utils._logging: Profiler config after instantiation: {'enabled': False}
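The Gloo connection lines above can be cross-checked against the parallelism settings in the resolved config (`data_parallel_replicate_dim: 1`, `data_parallel_shard_dim: -1`, `tensor_parallel_dim: 2`) and `--nproc_per_node 8`. This is a back-of-the-envelope sketch, not torchtune code: `resolve_mesh`, `peer_counts` and `steps_per_epoch` are hypothetical helpers, and the sketch assumes `-1` means "whatever is left of the world size once the other dimensions are fixed". Under that assumption the mesh is 4 (FSDP shard) x 2 (TP), which is consistent with the 7-, 3- and 1-peer connections logged, and 950 samples split over 4 data-parallel ranks at `batch_size: 4` gives the 59 steps per epoch shown in the `59/59` progress bar (assuming the JSON file holds 1000 conversations, as its name suggests).

```python
from math import prod


def resolve_mesh(world_size: int, dp_replicate: int, dp_shard: int, tp: int, cp: int = 1) -> dict:
    # Assumption: -1 for the shard dim means "use whatever is left of
    # world_size once the other mesh dims are fixed".
    fixed = dp_replicate * tp * cp
    if dp_shard == -1:
        if world_size % fixed:
            raise ValueError(f"world_size={world_size} is not divisible by {fixed}")
        dp_shard = world_size // fixed
    mesh = {"dp_replicate": dp_replicate, "dp_shard": dp_shard, "tp": tp, "cp": cp}
    if prod(mesh.values()) != world_size:
        raise ValueError(f"mesh {mesh} does not multiply out to world_size={world_size}")
    return mesh


def peer_counts(mesh: dict) -> dict:
    # A process group of n ranks connects each member to n - 1 peers, which is
    # the number each "[Gloo] Rank r is connected to k peer ranks" line reports.
    counts = {"world": prod(mesh.values()) - 1}
    counts.update({name: size - 1 for name, size in mesh.items() if size > 1})
    return counts


def steps_per_epoch(num_samples: int, batch_size: int, dp_degree: int) -> int:
    # Assumption: only data-parallel ranks split the dataset (TP peers see the
    # same micro-batch) and an incomplete final batch is dropped.
    return num_samples // (batch_size * dp_degree)


mesh = resolve_mesh(world_size=8, dp_replicate=1, dp_shard=-1, tp=2)
print(mesh)               # {'dp_replicate': 1, 'dp_shard': 4, 'tp': 2, 'cp': 1}
print(peer_counts(mesh))  # {'world': 7, 'dp_shard': 3, 'tp': 1}
# Hypothetical: 1000 conversations in the file, so split "train[:95%]" -> 950 samples.
dp_degree = mesh["dp_replicate"] * mesh["dp_shard"]
print(steps_per_epoch(950, batch_size=4, dp_degree=dp_degree))  # 59
```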

Error



1|59|Loss: 0.8372806906700134: 100%|██████████████████████████████████████████████████████████| 59/59 [36:29<00:00, 33.91s/it]
INFO:torchtune.utils._logging:Saving checkpoint. This may take some time. Retrieving full model state dict...
INFO:torchtune.utils._logging:Getting full model state dict took 70.13 secs
INFO:torchtune.utils._logging:Getting optimizer state dict...
INFO:torchtune.utils._logging:Getting optimizer state dict took 291.04 secs
[rank7]: Traceback (most recent call last):
[rank7]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1117, in <module>
[rank7]:     sys.exit(recipe_main())
[rank7]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank7]:     sys.exit(recipe_main(conf))
[rank7]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1112, in recipe_main
[rank7]:     recipe.train()
[rank7]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1075, in train
[rank7]:     self._checkpoint_client.save_checkpoint(
[rank7]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 393, in save_checkpoint
[rank7]:     self._save_checkpoint_sync(
[rank7]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 355, in _save_checkpoint_sync
[rank7]:     torch.distributed.barrier()
[rank7]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank7]:     return func(*args, **kwargs)
[rank7]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4820, in barrier
[rank7]:     work.wait()
[rank7]: RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:544] Connection closed by peer [10.138.0.25]:30610
[rank1]: Traceback (most recent call last):
[rank1]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1117, in <module>
[rank1]:     sys.exit(recipe_main())
[rank1]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank1]:     sys.exit(recipe_main(conf))
[rank1]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1112, in recipe_main
[rank1]:     recipe.train()
[rank1]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1075, in train
[rank1]:     self._checkpoint_client.save_checkpoint(
[rank1]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 393, in save_checkpoint
[rank1]:     self._save_checkpoint_sync(
[rank1]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 355, in _save_checkpoint_sync
[rank1]:     torch.distributed.barrier()
[rank1]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4820, in barrier
[rank1]:     work.wait()
[rank1]: RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:544] Connection closed by peer [10.138.0.25]:30610
[rank4]: Traceback (most recent call last):
[rank4]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1117, in <module>
[rank4]:     sys.exit(recipe_main())
[rank4]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank4]:     sys.exit(recipe_main(conf))
[rank4]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1112, in recipe_main
[rank4]:     recipe.train()
[rank4]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1075, in train
[rank4]:     self._checkpoint_client.save_checkpoint(
[rank4]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 393, in save_checkpoint
[rank4]:     self._save_checkpoint_sync(
[rank4]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 355, in _save_checkpoint_sync
[rank4]:     torch.distributed.barrier()
[rank4]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank4]:     return func(*args, **kwargs)
[rank4]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4820, in barrier
[rank4]:     work.wait()
[rank4]: RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:544] Connection closed by peer [10.138.0.25]:30610
[rank2]: Traceback (most recent call last):
[rank2]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1117, in <module>
[rank2]:     sys.exit(recipe_main())
[rank2]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank2]:     sys.exit(recipe_main(conf))
[rank2]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1112, in recipe_main
[rank2]:     recipe.train()
[rank2]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1075, in train
[rank2]:     self._checkpoint_client.save_checkpoint(
[rank2]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 393, in save_checkpoint
[rank2]:     self._save_checkpoint_sync(
[rank2]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 355, in _save_checkpoint_sync
[rank2]:     torch.distributed.barrier()
[rank2]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank2]:     return func(*args, **kwargs)
[rank2]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4820, in barrier
[rank2]:     work.wait()
[rank2]: RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:544] Connection closed by peer [10.138.0.25]:42181
[rank6]: Traceback (most recent call last):
[rank6]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1117, in <module>
[rank6]:     sys.exit(recipe_main())
[rank6]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank6]:     sys.exit(recipe_main(conf))
[rank6]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1112, in recipe_main
[rank6]:     recipe.train()
[rank6]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1075, in train
[rank6]:     self._checkpoint_client.save_checkpoint(
[rank6]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 393, in save_checkpoint
[rank6]:     self._save_checkpoint_sync(
[rank6]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 355, in _save_checkpoint_sync
[rank6]:     torch.distributed.barrier()
[rank6]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank6]:     return func(*args, **kwargs)
[rank6]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4820, in barrier
[rank6]:     work.wait()
[rank6]: RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:544] Connection closed by peer [10.138.0.25]:30610
[rank4]:[W716 20:53:12.395308535 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank2]:[W716 20:53:12.423445497 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank6]:[W716 20:53:14.958723379 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank1]:[W716 20:53:14.022187294 ProcessGroupNCCL.cpp:1566] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
[rank3]: Traceback (most recent call last):
[rank3]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1117, in <module>
[rank3]:     sys.exit(recipe_main())
[rank3]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank3]:     sys.exit(recipe_main(conf))
[rank3]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1112, in recipe_main
[rank3]:     recipe.train()
[rank3]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1075, in train
[rank3]:     self._checkpoint_client.save_checkpoint(
[rank3]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 393, in save_checkpoint
[rank3]:     self._save_checkpoint_sync(
[rank3]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 355, in _save_checkpoint_sync
[rank3]:     torch.distributed.barrier()
[rank3]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank3]:     return func(*args, **kwargs)
[rank3]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4820, in barrier
[rank3]:     work.wait()
[rank3]: RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:544] Connection closed by peer [10.138.0.25]:36068
[rank5]: Traceback (most recent call last):
[rank5]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1117, in <module>
[rank5]:     sys.exit(recipe_main())
[rank5]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/config/_parse.py", line 99, in wrapper
[rank5]:     sys.exit(recipe_main(conf))
[rank5]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1112, in recipe_main
[rank5]:     recipe.train()
[rank5]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py", line 1075, in train
[rank5]:     self._checkpoint_client.save_checkpoint(
[rank5]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 393, in save_checkpoint
[rank5]:     self._save_checkpoint_sync(
[rank5]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/training/checkpointing/_checkpoint_client.py", line 355, in _save_checkpoint_sync
[rank5]:     torch.distributed.barrier()
[rank5]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
[rank5]:     return func(*args, **kwargs)
[rank5]:   File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 4820, in barrier
[rank5]:     work.wait()
[rank5]: RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:544] Connection closed by peer [10.138.0.25]:61150
W0716 20:53:26.982000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 64189 closing signal SIGTERM
W0716 20:53:26.983000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 64190 closing signal SIGTERM
W0716 20:53:26.983000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 64191 closing signal SIGTERM
W0716 20:53:26.984000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 64192 closing signal SIGTERM
W0716 20:53:26.985000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 64193 closing signal SIGTERM
W0716 20:53:26.988000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 64194 closing signal SIGTERM
W0716 20:53:26.989000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:900] Sending process 64195 closing signal SIGTERM
W0716 20:53:56.990000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:919] Unable to shutdown process 64191 via Signals.SIGTERM, forcefully exiting via Signals.SIGKILL
W0716 20:54:02.560000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:919] Unable to shutdown process 64192 via Signals.SIGTERM, forcefully exiting via Signals.SIGKILL
E0716 20:54:14.628000 63909 site-packages/torch/distributed/elastic/multiprocessing/api.py:874] failed (exitcode: -9) local_rank: 0 (pid: 64188) of binary: /opt/conda/envs/torchtune/bin/python3.10
Traceback (most recent call last):
  File "/opt/conda/envs/torchtune/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 52, in main
    parser.run(args)
  File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 46, in run
    args.func(args)
  File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/_cli/run.py", line 212, in _run_cmd
    self._run_distributed(args, is_builtin=is_builtin)
  File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 357, in wrapper
    return f(*args, **kwargs)
  File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torchtune/_cli/run.py", line 101, in _run_distributed
    run(args)
  File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/run.py", line 892, in run
    elastic_launch(
  File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 143, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/opt/conda/envs/torchtune/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 277, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
/opt/conda/envs/torchtune/lib/python3.10/site-packages/recipes/full_finetune_distributed.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------

agunapal, Jul 16 '25 22:07

I would appreciate any pointers to the last working commit for Llama 4.

agunapal, Jul 16 '25 22:07

Hey! Interesting issue. I was attempting to fix #2856, but this looks like a separate issue with Llama4 :/ I will investigate further in this direction.

krammnic, Jul 17 '25 11:07

It looks like the rank zero device is getting stuck somewhere in this helper function: https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpoint_client.py#L355

Can you add some logs in there to see what's going on?
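A minimal sketch of the kind of logging suggested above, assuming you wrap each stage inside `_save_checkpoint_sync` by hand. `log_stage` and `peak_rss_gb` are hypothetical helpers, not part of torchtune. They use only the standard library, report peak resident memory (not current memory), and rely on the `resource` module, which is Unix-only.

```python
import logging
import resource
import sys
import time
from contextlib import contextmanager
from typing import Iterator

logger = logging.getLogger("ckpt_debug")  # hypothetical logger name


def peak_rss_gb() -> float:
    """Peak resident set size of the current process, in decimal GB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is in kilobytes on Linux but in bytes on macOS.
    peak_bytes = peak if sys.platform == "darwin" else peak * 1024
    return peak_bytes / 1e9


@contextmanager
def log_stage(stage: str, rank: int = 0) -> Iterator[dict]:
    """Log entry, exit, wall time, and peak RSS around one checkpointing stage."""
    stats = {"stage": stage}
    start = time.perf_counter()
    logger.info("[rank%d] entering %s (peak RSS %.2f GB)", rank, stage, peak_rss_gb())
    try:
        yield stats
    finally:
        stats["seconds"] = time.perf_counter() - start
        stats["peak_rss_gb"] = peak_rss_gb()
        logger.info(
            "[rank%d] leaving %s after %.1f s (peak RSS %.2f GB)",
            rank, stage, stats["seconds"], stats["peak_rss_gb"],
        )


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Stand-in for a real stage such as building the optimizer state dict.
    with log_stage("build dummy state", rank=0):
        dummy = {f"w{i}": [0.0] * 10_000 for i in range(100)}
```

Wrapping, for example, the optimizer state-dict call and the barrier in separate `with log_stage(...)` blocks would show which stage rank 0 entered but never left, since a hung stage logs only its "entering" line.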

calvinpelletier, Aug 18 '25 19:08

@mreso has been helping me root-cause this. The job is running out of CPU memory, which looks like a memory leak.

I am not yet able to reproduce the issue consistently with open-source data. However, even with 5 steps and batch size 4, on a multi-turn chat dataset with a large, variable-length system prompt (< 10k; padding may also come into play), CPU memory usage grows progressively, reaching 1.2 TB (out of 1.8 TB) before checkpointing even begins, and the run then hits a CPU OOM while checkpointing. Isn't 1.2 TB abnormally high?

@calvinpelletier

agunapal, Aug 20 '25 04:08

Yes, that is abnormally high. I'm digging into this.

calvinpelletier, Aug 21 '25 16:08

@joecummings, should Llama4 Scout use more CPU memory than Llama3.3 70B?

When full finetuning Scout (tune run --nproc_per_node 8 full_finetune_distributed --config llama4/scout_17B_16E_full on 8xA100s), CPU memory stays constant at 500 GB during training, then climbs to 800 GB once checkpointing starts. Getting the optimizer state dict then takes a long time, and the run never gets past that point: eventually CPU memory spikes hard and it crashes.

For comparison, full finetuning Llama3.3 70B (tune run --nproc_per_node 8 full_finetune_distributed --config llama3_3/70B_full) uses only ~350 GB during training, climbs to 750 GB during checkpointing, and checkpoints successfully without OOMing.

[Image: CPU memory usage plot]
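A back-of-envelope sketch of why Scout might need more CPU memory than a 70B dense model whenever the full, unsharded state is materialized. It assumes (none of this is stated in the thread) about 109B total parameters for Scout, bf16 weights, and an AdamW-style optimizer keeping two bf16 state tensors per parameter. In a mixture-of-experts model every expert's weights and optimizer state must be saved, even though only about 17B parameters are active per token.

```python
# Assumptions (not from this thread): parameter counts, bf16 weights, and an
# AdamW-style optimizer with two bf16 state tensors per parameter.
SCOUT_TOTAL_PARAMS = 109e9  # Llama 4 Scout: ~17B active, ~109B total (all 16 experts)
LLAMA33_70B_PARAMS = 70e9


def full_state_gb(num_params: float, param_bytes: int = 2,
                  optim_states: int = 2, optim_bytes: int = 2) -> float:
    """Decimal GB needed to hold unsharded weights plus optimizer state."""
    return num_params * (param_bytes + optim_states * optim_bytes) / 1e9


if __name__ == "__main__":
    for name, n in [("Scout", SCOUT_TOTAL_PARAMS), ("Llama 3.3 70B", LLAMA33_70B_PARAMS)]:
        print(f"{name}: weights {n * 2 / 1e9:.0f} GB, weights + optimizer {full_state_gb(n):.0f} GB")
```

Under these assumptions it prints 654 GB for Scout versus 420 GB for the 70B model, before counting any temporary copies made while gathering or serializing. That roughly 1.5x ratio is only a lower bound on the difference, not an explanation of the crash by itself.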

calvinpelletier, Aug 26 '25 00:08

@pradeepfn, could you take a look?

ekr0, Aug 26 '25 20:08

  1. In general, this is a checkpointing issue, and the high-level triage is correct.
  2. However, per the evidence above, the memory spike happens during state-dict preparation, before the dcp.save routines are called: https://github.com/pytorch/torchtune/blob/67ab86b94de9e7ac7dd9850113ebe69e2bbd307c/torchtune/training/checkpointing/_checkpoint_client.py#L335

Therefore, for the record, this is not a DCP-originated issue. It is more likely happening during the state-dict preparation step of the PyTorch distributed utilities, i.e. torch.distributed.state_dict.py.

pradeepfn, Sep 02 '25 21:09

https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpoint_client.py#L344-L346

This all-gathers the optimizer state dict across ranks, which could lead to high memory usage. Is this desired? @pradeepfn @calvinpelletier

XilunWu, Sep 02 '25 23:09

  1. The existing HF-formatted checkpoints in torchtune are not distributed. Therefore, all of the distributed state is gathered to rank 0 before dcp.save is called. This carries a high risk of OOMing on rank 0, and it is likely the situation we are experiencing.
  2. Checkpointing without OOMing: we should use distributed checkpointing. With distributed checkpoints, we do not gather the model state to rank 0; instead, each training rank writes its own shard to storage. However, there is a blocker to enabling HF-formatted distributed checkpointing in torchtune: it does not support distributed HF checkpointing yet.
    • Recently, Ankita from the DCP team started introducing that capability/integration, but it never landed: https://github.com/pytorch/torchtune/pull/2851/files
    • If we go this route, we should first enable distributed checkpointing without the HF format (this is already supported) and do a run. That will tell us whether distributed checkpointing actually solves the OOM issue. Afterwards, we can look into integrating HF support into distributed checkpoints (the PR above).
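A toy estimate contrasting the strategies discussed in this thread, in terms of the largest CPU buffer any single rank must hold and the sum across ranks on one node. It is a simplification, not a model of torchtune or DCP internals. It assumes all ranks share one host and each rank stages its whole buffer on CPU at the same moment, and it ignores communication buffers, serialization copies, and memory already used by training. `checkpoint_cpu_gb` and the strategy names are hypothetical.

```python
def checkpoint_cpu_gb(state_gb: float, world_size: int, strategy: str) -> tuple[float, float]:
    """Toy estimate: (largest per-rank CPU buffer, sum over all ranks on one node), in GB.

    Assumes every rank runs on the same host and stages its entire buffer in CPU
    memory at the same moment; ignores communication and serialization copies.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    if strategy == "gather_everywhere":
        # Every rank keeps its own full copy of the gathered state.
        return state_gb, state_gb * world_size
    if strategy == "gather_to_rank0":
        # Only rank 0 keeps the full state, as with a non-distributed HF checkpoint.
        return state_gb, state_gb
    if strategy == "sharded":
        # Each rank handles only its own shard, as with distributed checkpointing.
        shard = state_gb / world_size
        return shard, shard * world_size
    raise ValueError(f"unknown strategy: {strategy!r}")


if __name__ == "__main__":
    for strategy in ("gather_everywhere", "gather_to_rank0", "sharded"):
        # 654 GB is a hypothetical full-state size, not a measurement.
        per_rank, node_total = checkpoint_cpu_gb(654.0, 8, strategy)
        print(f"{strategy:>18}: largest rank {per_rank:7.2f} GB, node total {node_total:8.2f} GB")
```

With a hypothetical 654 GB full state and 8 ranks, gathering to rank 0 puts all 654 GB in one process, while the sharded path caps each rank at 81.75 GB. The node-wide sum is the same in this worst case; the gain is that no single process must hold everything, and a real writer may stage less than a whole shard at once. If every rank kept its own gathered copy, the node-wide sum would be 8x larger.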

pradeepfn, Sep 03 '25 18:09

Hi @agunapal, quick update: I rebased Ankita's PR onto main, since the APIs had become incompatible. The PR saves a lot of memory during the training/checkpoint-saving phase. The only issue was that loading the weights still created a spike, which led to an OOM on a machine with 1.5 TB of RAM. Here is a memory utilization plot from a 2.3 TB machine where it ran successfully:

[Image: memory utilization plot, 2.3 TB machine]

According to Ankita, this is because she did not implement DCP for the loading portion of the HFCheckpointer, so it gets the whole state_dict, which results in that spike. As a hacky workaround, I route the loading portion of the checkpointer through the non-DCP path, which seems to avoid the memory spike. I can now successfully train and save the model on a machine with 1.5 TB of RAM.

Memory utilization on 2.3TB machine:

[Image: memory utilization plot, 2.3 TB machine]

1.5 TB machine (just two epochs):

[Image: memory utilization plot, 1.5 TB machine]

You can try this version by installing from this branch and adding these two entries to your checkpointer:

checkpointer:
  # ...
  intermediate_hf_dir_dcp: ${output_dir}/hf_dir
  enable_dcp: True

mreso, Sep 16 '25 04:09