
Stuck when continuing training with auto_resume: true

Open · timpal0l opened this issue 6 months ago · 7 comments

Pretraining a large ModernBERT model on 1.1 trillion tokens

I am training on a multi-node setup with the following yaml, and it works fine, except that when I run it for the second time with auto_resume: true (to load the latest checkpoint), it just loads forever and nothing happens. I can see it allocates around 7 GB of VRAM on the GPU nodes, compared to 29 GB of VRAM when it runs successfully.

I only see:

Starting training...

Training yaml:

# Data paths
data_local: /project/scratch/p200667/dataset
data_remote: # If blank, files must be present in data_local

# Sequence & tokenizer
max_seq_len: 8192
tokenizer_name: answerdotai/ModernBERT-large
mlm_probability: 0.3
count_padding_tokens: false

# Run Name
run_name: modernbert-large-pretrain

# Model
model:
  name: flex_bert
  pretrained_model_name: bert-base-uncased # has to be set to bert-base-uncased for legacy reasons
  tokenizer_name: ${tokenizer_name}
  disable_train_metrics: true # save some time by not computing metrics on the training set
  model_config:
    vocab_size: 50368
    init_method: full_megatron
    num_hidden_layers: 28
    hidden_size: 1024
    intermediate_size: 2624
    num_attention_heads: 16 # to have head size of 64
    attention_layer: rope
    attention_probs_dropout_prob: 0.0
    attn_out_bias: false
    attn_out_dropout_prob: 0.1
    attn_qkv_bias: false
    bert_layer: prenorm
    embed_dropout_prob: 0.0
    embed_norm: true
    final_norm: true
    skip_first_prenorm: true
    embedding_layer: sans_pos
    loss_function: fa_cross_entropy
    loss_kwargs:
      reduction: mean
    mlp_dropout_prob: 0.0
    mlp_in_bias: false
    mlp_layer: glu
    mlp_out_bias: false
    normalization: layernorm
    norm_kwargs:
      eps: 1e-5
      bias: false
    hidden_act: gelu
    head_pred_act: gelu
    activation_function: gelu # better safe than sorry
    padding: unpadded
    rotary_emb_dim: null
    rotary_emb_base: 160000.0
    rotary_emb_scale_base: null
    rotary_emb_interleaved: false
    local_attn_rotary_emb_base: 10000.0
    local_attn_rotary_emb_dim: null
    allow_embedding_resizing: true
    sliding_window: 128
    global_attn_every_n_layers: 3
    unpad_embeddings: true
    compile_model: true
    masked_prediction: true

# Dataloaders
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split:
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: true
    mlm_probability: ${mlm_probability}
    streaming: false
  drop_last: true
  num_workers: 24
  sequence_packing: true
  persistent_workers: false
  pin_memory: false

# Optimization
scheduler:
  name: cosine_with_warmup
  t_warmup: 11_985_103_472tok # 1% of total ds
  t_max: ${max_duration}
  alpha_f: 0.001

optimizer:
  name: decoupled_stableadamw
  lr: 2e-4
  betas:
  - 0.9
  - 0.98
  eps: 1.0e-06
  weight_decay: 1.0e-5
  filter_bias_norm_wd: true
  log_grad_norm: true

# Training duration & batch sizes
max_duration: 1_198_510_347_252tok
eval_interval: 0
global_train_batch_size: 256
global_eval_batch_size: 256

# System settings
seed: 420
device_eval_batch_size: 2
device_train_microbatch_size: 2
precision: amp_bf16

# Logging & callbacks
progress_bar: false
log_to_console: true
console_log_interval: 10ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  scheduled_gc: {}
  log_grad_norm:
    batch_log_interval: 10
  packing_efficiency:
    log_interval: 10

# Checkpointing
save_interval: 5000ba
save_num_checkpoints_to_keep: 3
save_folder: checkpoints/{run_name}

#load_path: checkpoints/${run_name}/ep0-ba30000-rank0.pt
autoresume: true
#reset_time: false

#loggers:
#  wandb:
#    project: modernbert-large-pretrain
#    entity: tim
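
As a quick sanity check on the warmup setting, t_warmup is intended to be 1% of max_duration; the integer arithmetic below confirms that the two token counts in the yaml are consistent:

# t_warmup (11_985_103_472 tok) should equal 1% of max_duration (1_198_510_347_252 tok).
assert 1_198_510_347_252 // 100 == 11_985_103_472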

I launch it with a bash script on an HPC cluster using Slurm on 16 nodes:

#!/bin/bash -l
#SBATCH -A p200XXX
#SBATCH -p gpu
#SBATCH --qos=default
#SBATCH --nodes=16               # total nodes
#SBATCH --ntasks-per-node=1      # one Slurm task per node
#SBATCH --gres=gpu:4             # 4 GPUs per node
#SBATCH --cpus-per-task=128
#SBATCH --time=10:00:00
#SBATCH --job-name=modernbert_ddp
#SBATCH --output=logs/modernbert_%j.out
#SBATCH --error=logs/modernbert_%j.err
#SBATCH --exclusive              # no other jobs on these nodes
#SBATCH --wait-all-nodes=1       # start only once all nodes are up

set -eo pipefail
ulimit -n 8192

#module purge
module load NCCL/2.22.3-GCCcore-13.3.0-CUDA-12.6.0

# Ensure we use the system NCCL
export LD_LIBRARY_PATH=$EBROOTNCCL/lib:$LD_LIBRARY_PATH

source /project/scratch/p200667/miniconda/etc/profile.d/conda.sh
conda activate bert2026

# Logging / debug
export LOGLEVEL=INFO
export NCCL_DEBUG=TRACE
export TORCH_CPP_LOG_LEVEL=INFO
# export NCCL_TIMEOUT=900
# NCCL / IB-verbs tunables
export NCCL_SOCKET_IFNAME=ib0
# export NCCL_IB_HCA=mlx5_0
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1

export OMP_NUM_THREADS=8

# Gather node list and compute world‐size
NODES=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )
NNODES=${#NODES[@]}
HEAD_NODE=${NODES[0]}

# Derive the IP of the head node on ib0
MASTER_ADDR=$( srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address )
MASTER_PORT=$(( RANDOM % 10000 + 20000 ))  # random port in [20000–29999]

# DDP parameters
NPROC=4
WORLD_SIZE=$(( NNODES * NPROC ))

CONFIG=/project/scratch/p200667/timpal0l/ModernBERT/training/modernbert-large-learning-rate-decay.yaml

echo "=== DDP CONFIGURATION ==="
echo "  NODES:           ${NODES[*]}"
echo "  NNODES:          $NNODES"
echo "  HEAD_NODE:       $HEAD_NODE"
echo "  MASTER_ADDR:     $MASTER_ADDR"
echo "  MASTER_PORT:     $MASTER_PORT"
echo "  NPROC per node:  $NPROC"
echo "  WORLD_SIZE:      $WORLD_SIZE"
echo "  CONFIG:          $CONFIG"
echo "========================="

# Function to launch Composer on a given node‐rank
run_compose() {
    local NODE=$1
    local NODE_RANK=$2

    echo ">>> Launching node_rank=$NODE_RANK on $NODE"
    srun --nodelist=$NODE \
          --ntasks=1 \
          --cpus-per-task=128 \
          --gres=gpu:4 \
        composer \
          --nproc        $NPROC \
          --world_size   $WORLD_SIZE \
          --node_rank    $NODE_RANK \
          --base_rank    $(( NODE_RANK * NPROC )) \
          --master_addr  $MASTER_ADDR \
          --master_port  $MASTER_PORT \
          --stdout	 logs/modernbert_${SLURM_JOB_ID}_rank{rank}.out \
          --stderr	 logs/modernbert_${SLURM_JOB_ID}_rank{rank}.err \
          --verbose \
        main.py "$CONFIG"
}

# Launch worker nodes (ranks 1 … NNODES–1) in the background
for (( NODE_RANK=1; NODE_RANK<NNODES; NODE_RANK++ )); do
    run_compose "${NODES[$NODE_RANK]}" $NODE_RANK &
done

# Launch head node as rank 0 in the foreground
run_compose $HEAD_NODE 0

# Wait for all background launches to finish
wait
echo "All ranks finished."

The main.py is the one from the ModernBERT repo (linked below). My environment:

[u102342@mel2035 ~]$ pip list
annotated-types             0.7.0
antlr4-python3-runtime      4.9.3
anyio                       4.9.0
azure-core                  1.34.0
azure-identity              1.23.0
azure-storage-blob          12.25.1
azure-storage-file-datalake 12.20.0
backoff                     2.2.1
bcrypt                      4.3.0
boto3                       1.38.17
botocore                    1.38.17
Brotli                      1.1.0
cachetools                  5.5.2
certifi                     2025.4.26
cffi                        1.17.1
charset-normalizer          3.4.2
circuitbreaker              2.1.3
click                       8.2.0
contourpy                   1.3.2
coolname                    2.2.0
cramjam                     2.10.0
cryptography                44.0.3
cycler                      0.12.1
docker-pycreds              0.4.0
einops                      0.8.0
filelock                    3.18.0
flash-attn                  2.6.3
fonttools                   4.58.0
fsspec                      2025.3.2
gitdb                       4.0.12
GitPython                   3.1.44
google-api-core             2.25.0rc1
google-auth                 2.40.1
google-cloud-core           2.4.3
google-cloud-storage        2.10.0
google-crc32c               1.7.1
google-resumable-media      2.7.2
googleapis-common-protos    1.70.0
gql                         3.5.2
graphql-core                3.2.4
huggingface-hub             0.31.2
idna                        3.10
importlib_metadata          8.4.0
isodate                     0.7.2
Jinja2                      3.1.6
jmespath                    1.0.1
kiwisolver                  1.4.8
lightning-utilities         0.14.3
llvmlite                    0.43.0
markdown-it-py              3.0.0
MarkupSafe                  3.0.2
matplotlib                  3.10.3
mdurl                       0.1.2
mosaicml                    0.30.0
mosaicml-cli                0.7.2
mosaicml-streaming          0.7.6
mpmath                      1.3.0
msal                        1.32.3
msal-extensions             1.3.1
multidict                   6.4.3
myquota                     0.3.3
networkx                    3.4.2
ninja                       1.11.1.4
numba                       0.60.0
numpy                       2.0.2
oci                         2.152.0
omegaconf                   2.3.0
packaging                   25.0
paramiko                    3.5.1
pillow                      11.2.1
pip                         25.1
platformdirs                4.3.8
prompt_toolkit              3.0.51
propcache                   0.3.1
proto-plus                  1.26.1
protobuf                    6.31.0
psutil                      7.0.0
py-cpuinfo                  9.0.0
pyasn1                      0.6.1
pyasn1_modules              0.4.2
pycparser                   2.22
pydantic                    2.11.4
pydantic_core               2.33.2
Pygments                    2.19.1
PyJWT                       2.10.1
PyNaCl                      1.5.0
pyOpenSSL                   24.3.0
pyparsing                   3.2.3
python-dateutil             2.9.0.post0
python-snappy               0.7.3
pytorch-ranger              0.1.1
pytz                        2025.2
PyYAML                      6.0.2
questionary                 2.1.0
regex                       2024.11.6
requests                    2.32.3
rich                        14.0.0
rsa                         4.9.1
ruamel.yaml                 0.18.10
ruamel.yaml.clib            0.2.12
s3transfer                  0.12.0
safetensors                 0.5.3
sentry-sdk                  2.28.0
setproctitle                1.3.6
setuptools                  78.1.1
six                         1.17.0
smmap                       5.0.2
sniffio                     1.3.1
sympy                       1.14.0
tokenizers                  0.19.1
torch                       2.4.0a0+gitd990dad
torch-optimi                0.2.1
torch-optimizer             0.3.0
torchmetrics                1.4.0.post0
torchvision                 0.19.0
tqdm                        4.67.1
transformers                4.44.1
triton                      3.0.0
typing_extensions           4.13.2
typing-inspection           0.4.0
tzdata                      2024.1
urllib3                     2.4.0
validators                  0.35.0
wandb                       0.19.11
wcwidth                     0.2.13
websockets                  15.0.1
wheel                       0.45.1
xxhash                      3.5.0
yarl                        1.20.0
zipp                        3.21.0
zstandard                   0.23.0
zstd                        1.5.7.0

My gut feeling says it might be something when loading the dataset for the second time: it hangs, or has trouble finding where it left off?

The dataset is non-compressed MDS shards containing tokenized uint16 NumPy arrays of input_ids.

timpal0l · May 20 '25 21:05
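
For context, shards like these are typically written with mosaicml-streaming's MDSWriter. A minimal sketch of what that step could look like is below; the output path, column encoding string, and token source are illustrative assumptions, not the exact pipeline used in this issue.

# Hypothetical sketch: writing pre-tokenized uint16 input_ids to uncompressed MDS shards.
# Paths, the column encoding string, and the sample generator are assumptions for illustration.
import numpy as np
from streaming import MDSWriter

out_dir = "/project/scratch/p200667/dataset"   # assumed local shard directory (data_local above)
columns = {"input_ids": "ndarray:uint16"}      # encoding names can differ across streaming versions

def iter_tokenized_samples():
    # Placeholder generator; in practice this yields pre-tokenized documents.
    for _ in range(1000):
        yield np.random.randint(0, 50368, size=8192, dtype=np.uint16)

with MDSWriter(out=out_dir, columns=columns, compression=None) as writer:
    for input_ids in iter_tokenized_samples():
        writer.write({"input_ids": input_ids})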

May not be able to help in great detail since it's a setup we don't do much, but I'm guessing you are using a persistent file system? You can try running https://github.com/mosaicml/streaming/blob/e764d0c54667ae3306c9d9d780d571720d9d56e9/streaming/base/util.py#L169 prior to attempting to load any data. It might help if it's the streaming dataset loading that is the issue.

dakinggg · May 20 '25 23:05
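
For reference, the utility linked above can be called at the top of main.py, before any dataset or dataloader is constructed. A minimal sketch, assuming mosaicml-streaming is installed:

# Sketch: clear leftover shared-memory state from a previous crashed or killed run
# before any streaming dataset is constructed (this is the utility linked in the comment above).
from streaming.base.util import clean_stale_shared_memory

clean_stale_shared_memory()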

@dakinggg Thanks for your reply! Am not using streaming. This is my dataloader:

train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split:
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: true
    mlm_probability: ${mlm_probability}
    streaming: false
  drop_last: true
  num_workers: 24
  sequence_packing: true
  persistent_workers: false
  pin_memory: false

timpal0l · May 21 '25 20:05

Not sure what codebase you're using, but that yaml looks like the LLM Foundry yamls we use with streaming, so I think you might be :)

dakinggg · May 21 '25 21:05

@dakinggg even with streaming: false?

I tried setting streaming: true for a different training run, where I had to turn off sequence packing, and that yielded way lower tokens/s.

But I'm pretty sure this one is not using streaming?

timpal0l · May 21 '25 21:05

Ah sorry, I don't know what codebase you are using, so I don't know how the yaml translates into code.

dakinggg · May 21 '25 22:05

> Ah sorry, I don't know what codebase you are using, so I don't know how the yaml translates into code.

Aha, I am launching it with composer main.py trainer.yaml

With: https://github.com/AnswerDotAI/ModernBERT/tree/pretraining_documentation

I don't think it is streaming, but I did add the clean_stale_shared_memory() call in my main script just in case. No difference.

timpal0l · May 21 '25 22:05

Ah, it might be an issue in the answerdotai library then. I'd recommend adding lots of logging to see where it's getting stuck, and probably opening an issue on that repo :)

dakinggg · May 21 '25 22:05
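
One low-effort way to see where a rank is stuck is a periodic stack dump via the standard-library faulthandler module. The snippet below is an illustrative addition to main.py, not something the repo ships:

# Sketch: every 120 seconds, dump the stack of every thread to stderr so a hang's
# location shows up in the Slurm error logs. faulthandler is in the standard library.
import faulthandler
import sys

faulthandler.dump_traceback_later(120, repeat=True, file=sys.stderr)

Alternatively, running py-spy dump --pid <PID> against a hung process gives a one-off stack trace without modifying any code.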

Hi @timpal0l, I’m encountering the same issue when resuming training. Resuming after a few iterations works quickly, but trying to resume after just 2500 steps takes 5 minutes. I am using the NoStreamingDataset with a pre-tokenized dataset. Did you find any way to speed up the resume process?

dtamayo-nlp · Oct 08 '25 07:10

> Hi @timpal0l, I’m encountering the same issue when resuming training. Resuming after a few iterations works quickly, but trying to resume after just 2500 steps takes 5 minutes. I am using the NoStreamingDataset with a pre-tokenized dataset. Did you find any way to speed up the resume process?

@dtamayo-nlp Hi, yes. For me the solution was to use streaming. So I re-tokenized the whole dataset and packed it manually, and used a streaming dataset.

timpal0l · Nov 05 '25 09:11
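
For anyone hitting the same thing: the fix described above amounts to reading the re-tokenized, pre-packed MDS shards through a StreamingDataset rather than the non-streaming path. A rough sketch of the idea; the path, batch size, and worker count are taken from the config above, everything else is an assumption rather than the repo's exact dataloader code:

# Sketch: load locally stored MDS shards through mosaicml-streaming so that resume state
# (sample order and position) is tracked by the dataset and dataloader themselves.
from streaming import StreamingDataset, StreamingDataLoader

dataset = StreamingDataset(
    local="/project/scratch/p200667/dataset",  # data_local from the yaml above
    shuffle=True,
    batch_size=2,                              # matches device_train_microbatch_size
)
# StreamingDataLoader keeps a state_dict of how many samples have been seen, so a resumed
# run can fast-forward deterministically instead of re-reading the dataset from scratch.
loader = StreamingDataLoader(dataset, batch_size=2, num_workers=24, drop_last=True)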

Thanks! It also worked for me.

dtamayo-nlp · Nov 11 '25 10:11