Stuck when continuing training with autoresume: true
Pretraining a large ModernBERT model on 1.1 trillion tokens
I am training on a multi-node setup with the following YAML, and it works fine, except that when I run it a second time with autoresume: true (to load the latest checkpoint), it just loads forever and nothing happens. I can see it allocates around 7 GB of VRAM on the GPU nodes; it allocates 29 GB of VRAM when running successfully.
I only see:
Starting training...
Training yaml:
# Data paths
data_local: /project/scratch/p200667/dataset
data_remote: # If blank, files must be present in data_local
# Sequence & tokenizer
max_seq_len: 8192
tokenizer_name: answerdotai/ModernBERT-large
mlm_probability: 0.3
count_padding_tokens: false
# Run Name
run_name: modernbert-large-pretrain
# Model
model:
  name: flex_bert
  pretrained_model_name: bert-base-uncased # has to be set to bert-base-uncased for legacy reasons
  tokenizer_name: ${tokenizer_name}
  disable_train_metrics: true # save some time by not computing metrics on the training set
  model_config:
    vocab_size: 50368
    init_method: full_megatron
    num_hidden_layers: 28
    hidden_size: 1024
    intermediate_size: 2624
    num_attention_heads: 16 # to have head size of 64
    attention_layer: rope
    attention_probs_dropout_prob: 0.0
    attn_out_bias: false
    attn_out_dropout_prob: 0.1
    attn_qkv_bias: false
    bert_layer: prenorm
    embed_dropout_prob: 0.0
    embed_norm: true
    final_norm: true
    skip_first_prenorm: true
    embedding_layer: sans_pos
    loss_function: fa_cross_entropy
    loss_kwargs:
      reduction: mean
    mlp_dropout_prob: 0.0
    mlp_in_bias: false
    mlp_layer: glu
    mlp_out_bias: false
    normalization: layernorm
    norm_kwargs:
      eps: 1e-5
      bias: false
    hidden_act: gelu
    head_pred_act: gelu
    activation_function: gelu # better safe than sorry
    padding: unpadded
    rotary_emb_dim: null
    rotary_emb_base: 160000.0
    rotary_emb_scale_base: null
    rotary_emb_interleaved: false
    local_attn_rotary_emb_base: 10000.0
    local_attn_rotary_emb_dim: null
    allow_embedding_resizing: true
    sliding_window: 128
    global_attn_every_n_layers: 3
    unpad_embeddings: true
    compile_model: true
    masked_prediction: true
# Dataloaders
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split:
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: true
    mlm_probability: ${mlm_probability}
    streaming: false
  drop_last: true
  num_workers: 24
  sequence_packing: true
  persistent_workers: false
  pin_memory: false
# Optimization
scheduler:
  name: cosine_with_warmup
  t_warmup: 11_985_103_472tok # 1% of total ds
  t_max: ${max_duration}
  alpha_f: 0.001

optimizer:
  name: decoupled_stableadamw
  lr: 2e-4
  betas:
  - 0.9
  - 0.98
  eps: 1.0e-06
  weight_decay: 1.0e-5
  filter_bias_norm_wd: true
  log_grad_norm: true
# Training duration & batch sizes
max_duration: 1_198_510_347_252tok
eval_interval: 0
global_train_batch_size: 256
global_eval_batch_size: 256
# System settings
seed: 420
device_eval_batch_size: 2
device_train_microbatch_size: 2
precision: amp_bf16
# Logging & callbacks
progress_bar: false
log_to_console: true
console_log_interval: 10ba
callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  scheduled_gc: {}
  log_grad_norm:
    batch_log_interval: 10
  packing_efficiency:
    log_interval: 10
# Checkpointing
save_interval: 5000ba
save_num_checkpoints_to_keep: 3
save_folder: checkpoints/{run_name}
#load_path: checkpoints/${run_name}/ep0-ba30000-rank0.pt
autoresume: true
#reset_time: false
#loggers:
#  wandb:
#    project: modernbert-large-pretrain
#    entity: tim
I launch it with a bash script on an HPC cluster using Slurm across 16 nodes:
#!/bin/bash -l
#SBATCH -A p200XXX
#SBATCH -p gpu
#SBATCH --qos=default
#SBATCH --nodes=16 # total nodes
#SBATCH --ntasks-per-node=1 # one Slurm task per node
#SBATCH --gres=gpu:4 # 4 GPUs per node
#SBATCH --cpus-per-task=128
#SBATCH --time=10:00:00
#SBATCH --job-name=modernbert_ddp
#SBATCH --output=logs/modernbert_%j.out
#SBATCH --error=logs/modernbert_%j.err
#SBATCH --exclusive # no other jobs on these nodes
#SBATCH --wait-all-nodes=1 # start only once all nodes are up
set -eo pipefail
ulimit -n 8192
#module purge
module load NCCL/2.22.3-GCCcore-13.3.0-CUDA-12.6.0
# Ensure we use the system NCCL
export LD_LIBRARY_PATH=$EBROOTNCCL/lib:$LD_LIBRARY_PATH
source /project/scratch/p200667/miniconda/etc/profile.d/conda.sh
conda activate bert2026
# Logging / debug
export LOGLEVEL=INFO
export NCCL_DEBUG=TRACE
export TORCH_CPP_LOG_LEVEL=INFO
# export NCCL_TIMEOUT=900
# NCCL / IB-verbs tunables
export NCCL_SOCKET_IFNAME=ib0
# export NCCL_IB_HCA=mlx5_0
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=8
# Gather node list and compute world‐size
NODES=( $(scontrol show hostnames $SLURM_JOB_NODELIST) )
NNODES=${#NODES[@]}
HEAD_NODE=${NODES[0]}
# Derive the IP of the head node on ib0
MASTER_ADDR=$( srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address )
MASTER_PORT=$(( RANDOM % 10000 + 20000 )) # random port in [20000–29999]
# DDP parameters
NPROC=4
WORLD_SIZE=$(( NNODES * NPROC ))
CONFIG=/project/scratch/p200667/timpal0l/ModernBERT/training/modernbert-large-learning-rate-decay.yaml
echo "=== DDP CONFIGURATION ==="
echo " NODES: ${NODES[*]}"
echo " NNODES: $NNODES"
echo " HEAD_NODE: $HEAD_NODE"
echo " MASTER_ADDR: $MASTER_ADDR"
echo " MASTER_PORT: $MASTER_PORT"
echo " NPROC per node: $NPROC"
echo " WORLD_SIZE: $WORLD_SIZE"
echo " CONFIG: $CONFIG"
echo "========================="
# Function to launch Composer on a given node‐rank
run_compose() {
  local NODE=$1
  local NODE_RANK=$2
  echo ">>> Launching node_rank=$NODE_RANK on $NODE"
  srun --nodelist=$NODE \
       --ntasks=1 \
       --cpus-per-task=128 \
       --gres=gpu:4 \
       composer \
         --nproc $NPROC \
         --world_size $WORLD_SIZE \
         --node_rank $NODE_RANK \
         --base_rank $(( NODE_RANK * NPROC )) \
         --master_addr $MASTER_ADDR \
         --master_port $MASTER_PORT \
         --stdout logs/modernbert_${SLURM_JOB_ID}_rank{rank}.out \
         --stderr logs/modernbert_${SLURM_JOB_ID}_rank{rank}.err \
         --verbose \
         main.py "$CONFIG"
}
# Launch worker nodes (ranks 1 … NNODES–1) in the background
for (( NODE_RANK=1; NODE_RANK<NNODES; NODE_RANK++ )); do
run_compose "${NODES[$NODE_RANK]}" $NODE_RANK &
done
# Launch head node as rank 0 in the foreground
run_compose $HEAD_NODE 0
# Wait for all background launches to finish
wait
echo "All ranks finished."
The main.py is the one from the ModernBERT repo. My environment:
[u102342@mel2035 ~]$ pip list
annotated-types 0.7.0
antlr4-python3-runtime 4.9.3
anyio 4.9.0
azure-core 1.34.0
azure-identity 1.23.0
azure-storage-blob 12.25.1
azure-storage-file-datalake 12.20.0
backoff 2.2.1
bcrypt 4.3.0
boto3 1.38.17
botocore 1.38.17
Brotli 1.1.0
cachetools 5.5.2
certifi 2025.4.26
cffi 1.17.1
charset-normalizer 3.4.2
circuitbreaker 2.1.3
click 8.2.0
contourpy 1.3.2
coolname 2.2.0
cramjam 2.10.0
cryptography 44.0.3
cycler 0.12.1
docker-pycreds 0.4.0
einops 0.8.0
filelock 3.18.0
flash-attn 2.6.3
fonttools 4.58.0
fsspec 2025.3.2
gitdb 4.0.12
GitPython 3.1.44
google-api-core 2.25.0rc1
google-auth 2.40.1
google-cloud-core 2.4.3
google-cloud-storage 2.10.0
google-crc32c 1.7.1
google-resumable-media 2.7.2
googleapis-common-protos 1.70.0
gql 3.5.2
graphql-core 3.2.4
huggingface-hub 0.31.2
idna 3.10
importlib_metadata 8.4.0
isodate 0.7.2
Jinja2 3.1.6
jmespath 1.0.1
kiwisolver 1.4.8
lightning-utilities 0.14.3
llvmlite 0.43.0
markdown-it-py 3.0.0
MarkupSafe 3.0.2
matplotlib 3.10.3
mdurl 0.1.2
mosaicml 0.30.0
mosaicml-cli 0.7.2
mosaicml-streaming 0.7.6
mpmath 1.3.0
msal 1.32.3
msal-extensions 1.3.1
multidict 6.4.3
myquota 0.3.3
networkx 3.4.2
ninja 1.11.1.4
numba 0.60.0
numpy 2.0.2
oci 2.152.0
omegaconf 2.3.0
packaging 25.0
paramiko 3.5.1
pillow 11.2.1
pip 25.1
platformdirs 4.3.8
prompt_toolkit 3.0.51
propcache 0.3.1
proto-plus 1.26.1
protobuf 6.31.0
psutil 7.0.0
py-cpuinfo 9.0.0
pyasn1 0.6.1
pyasn1_modules 0.4.2
pycparser 2.22
pydantic 2.11.4
pydantic_core 2.33.2
Pygments 2.19.1
PyJWT 2.10.1
PyNaCl 1.5.0
pyOpenSSL 24.3.0
pyparsing 3.2.3
python-dateutil 2.9.0.post0
python-snappy 0.7.3
pytorch-ranger 0.1.1
pytz 2025.2
PyYAML 6.0.2
questionary 2.1.0
regex 2024.11.6
requests 2.32.3
rich 14.0.0
rsa 4.9.1
ruamel.yaml 0.18.10
ruamel.yaml.clib 0.2.12
s3transfer 0.12.0
safetensors 0.5.3
sentry-sdk 2.28.0
setproctitle 1.3.6
setuptools 78.1.1
six 1.17.0
smmap 5.0.2
sniffio 1.3.1
sympy 1.14.0
tokenizers 0.19.1
torch 2.4.0a0+gitd990dad
torch-optimi 0.2.1
torch-optimizer 0.3.0
torchmetrics 1.4.0.post0
torchvision 0.19.0
tqdm 4.67.1
transformers 4.44.1
triton 3.0.0
typing_extensions 4.13.2
typing-inspection 0.4.0
tzdata 2024.1
urllib3 2.4.0
validators 0.35.0
wandb 0.19.11
wcwidth 0.2.13
websockets 15.0.1
wheel 0.45.1
xxhash 3.5.0
yarl 1.20.0
zipp 3.21.0
zstandard 0.23.0
zstd 1.5.7.0
My gut feeling says it might be something when loading the dataset for the second time; it hangs or has trouble finding where it left off?
The dataset is uncompressed MDS shards containing tokenized uint16 numpy arrays of input_ids.
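For reference, the shards were produced with streaming's MDSWriter, roughly along these lines (placeholder values and a random stand-in for the corpus, not my exact tokenization script):

# Illustrative only: writing uncompressed MDS shards of uint16 input_ids
# with mosaicml-streaming. Values below are placeholders.
import numpy as np
from streaming import MDSWriter

max_seq_len = 8192
columns = {"input_ids": "ndarray:uint16"}  # ndarray encoding; exact encoding depends on streaming version

with MDSWriter(out="/project/scratch/p200667/dataset",
               columns=columns,
               compression=None) as writer:  # non-compressed shards
    for _ in range(4):  # stand-in for iterating over the real tokenized corpus
        ids = np.random.randint(0, 50368, size=max_seq_len, dtype=np.uint16)
        writer.write({"input_ids": ids})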
I may not be able to help in great detail since it's a setup we don't use much, but I'm guessing you are using a persistent file system? You can try running https://github.com/mosaicml/streaming/blob/e764d0c54667ae3306c9d9d780d571720d9d56e9/streaming/base/util.py#L169 prior to attempting to load any data. It might help if it's the streaming dataset loading that is the issue.
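For reference, the linked helper is clean_stale_shared_memory. A minimal sketch of calling it before any dataloaders are built (assuming mosaicml-streaming is installed):

# Clear stale shared-memory state left behind by a previous (crashed or killed)
# streaming run, before constructing dataloaders / the Trainer.
from streaming.base.util import clean_stale_shared_memory

clean_stale_shared_memory()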
@dakinggg Thanks for your reply! I am not using streaming. This is my dataloader:
train_loader:
  name: text
  dataset:
    local: ${data_local}
    remote: ${data_remote}
    split:
    tokenizer_name: ${tokenizer_name}
    max_seq_len: ${max_seq_len}
    shuffle: true
    mlm_probability: ${mlm_probability}
    streaming: false
  drop_last: true
  num_workers: 24
  sequence_packing: true
  persistent_workers: false
  pin_memory: false
Not sure what codebase you're using, but that yaml looks like the LLM Foundry yamls we use with streaming, so I think you might be :)
@dakinggg even with streaming: false?
I tried setting streaming: true for a different training run, where I had to turn off sequence packing, and that yielded way lower tokens/s.
But I'm pretty sure this one is not using streaming?
Ah sorry, I don't know what codebase you are using, so I don't know how the yaml translates into code.
Aha, I am launching it with composer main.py trainer.yaml
With: https://github.com/AnswerDotAI/ModernBERT/tree/pretraining_documentation
I don't think it is streaming, but I did add the clean_stale_shared_memory() call in my main script just in case. No difference.
Ah, it might be an issue in the answerdotai library then. I'd recommend adding lots of logging to see where it's getting stuck, and probably opening an issue on that repo :)
Hi @timpal0l, I’m encountering the same issue when resuming training. Resuming after a few iterations works quickly, but trying to resume after just 2500 steps takes 5 minutes. I am using the NoStreamingDataset with a pre-tokenized dataset. Did you find any way to speed up the resume process?
@dtamayo-nlp Hi, yes. For me the solution was to use streaming. So I re-tokenized the whole dataset, packed it manually, and used a streaming dataset.
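Roughly, it ended up looking like the sketch below (placeholder path, not my exact code). As far as I understand, the key difference is that StreamingDataset checkpoints its own sample position, so resuming doesn't have to spin the dataloader forward batch by batch; in the yaml this corresponds to the same train_loader but with streaming: true, with the sequence packing done offline instead of on the fly.

# Illustrative sketch: loading manually packed MDS shards via StreamingDataset.
from streaming import StreamingDataset
from torch.utils.data import DataLoader

dataset = StreamingDataset(
    local="/project/scratch/p200667/dataset_packed",  # placeholder: manually packed shards
    remote=None,
    shuffle=True,
    batch_size=2,  # matches device_train_microbatch_size
)
loader = DataLoader(dataset, batch_size=2, num_workers=24, drop_last=True)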
Thanks! It also worked for me.