High CrossEntropy and Z Loss variance after loading from checkpoint
🐛 Describe the bug
I have been playing with configs/official-1124/OLMo-7B-stage1.yaml and training on the dataset listed in the YAML file. Unfortunately, I have found a strange issue: after loading from a checkpoint, the variance in Cross Entropy and Z Loss increases dramatically. For example, I ran a first run until step 5600 and then restarted training from the checkpoint at step 4400. Here are the loss graphs from wandb:
You can clearly see that after step 4400 the variance is noticeably higher.
I have tried this on the following two systems, and both show the same problem:
- 64 AMD MI300X GPUs across 8 nodes, using ROCm 6.1, PyTorch 2.5.1, and Python 3.11
- 64 NVIDIA A100 GPUs across 8 nodes, using CUDA 12.4, PyTorch 2.5.1, and Python 3.11
I have also tried changing the number of heads and layers in OLMo-7B-stage1.yaml (16 and 32), but both show the same issue. I have been using the OLMo Core checkpointer, loading checkpoints as follows (see the sketch after this list):
- First, collect the tensors from every node's `model`, `train`, and `optim` checkpoint folders into a single folder accessible to all nodes.
- Then set `--load_path=` to that folder containing all the tensors.
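Roughly, the consolidation step looks like the sketch below (the per-node and shared paths are hypothetical placeholders for our cluster layout, and the shard layout is assumed to follow the `model`/`train`/`optim` structure above):

```python
# Rough sketch of the consolidation step (hypothetical paths; adjust to your
# cluster). It merges each node's `model`, `train`, and `optim` checkpoint
# shards into one folder that every node can read via `--load_path=`.
import shutil
from pathlib import Path

node_ckpt_dirs = [Path(f"/local/node{i}/checkpoints/step4400") for i in range(8)]  # hypothetical
shared_ckpt_dir = Path("/shared/checkpoints/step4400")                             # hypothetical

for subdir in ("model", "train", "optim"):
    (shared_ckpt_dir / subdir).mkdir(parents=True, exist_ok=True)
    for node_dir in node_ckpt_dirs:
        for shard in sorted((node_dir / subdir).iterdir()):
            if not shard.is_file():
                continue
            target = shared_ckpt_dir / subdir / shard.name
            if not target.exists():  # per-rank shard files have unique names
                shutil.copy2(shard, target)
```

Training is then restarted with `--load_path=/shared/checkpoints/step4400` (again, a hypothetical path).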
Below is the config I used (I removed the dataset URLs):
```yaml
run_name: OLMo2-7B-stage1
seed: 6198
dry_run: false

model:
  d_model: 4096
  n_heads: 32
  n_layers: 32
  mlp_hidden_size: 22016
  weight_tying: false
  alibi: false
  rope: true
  rope_theta: 500000
  flash_attention: true
  attention_dropout: 0.0
  include_bias: false
  block_type: sequential
  layer_norm_type: rms
  layer_norm_with_affine: true
  layer_norm_eps: 1e-6
  bias_for_layer_norm: false
  attention_layer_norm: true
  attention_layer_norm_with_affine: true
  norm_after: true
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 4096
  vocab_size: 100278
  embedding_size: 100352
  eos_token_id: 100257
  pad_token_id: 100277
  init_device: meta
  init_fn: normal
  init_std: 0.02
  init_cutoff_factor: 3

softmax_auxiliary_loss: true
auxiliary_loss_multiplier: 1e-5
fused_loss: true

compile: null

wandb:
  project: "llm-kron"
  entity: "abhijangda-microsoft"
  log_interval: 1
  group: "7B"

optimizer:
  name: adamw
  learning_rate: 3.0e-4
  weight_decay: 0.1
  eps: 1e-8
  decay_norm_and_bias: true
  decay_embeddings: false
  betas:
  - 0.9
  - 0.95
  metrics_log_interval: 1

scheduler:
  name: cosine_with_warmup
  units: tokens
  t_warmup: 8388608000
  t_max: 5e12
  alpha_f: 0.1
  warmup_min_lr: 0.0

tokenizer:
  identifier: tokenizers/allenai_dolma2.json
  truncate_direction: right

save_overwrite: false
save_interval: 1000
save_interval_ephemeral: 250
save_num_checkpoints_to_keep: -1
sharded_checkpointer: olmo_core
save_interval_unsharded: null
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 1ep
global_train_batch_size: 1024
device_train_microbatch_size: 8

precision: amp_bf16

fsdp:
  wrapping_strategy: by_block_and_size
  precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 1

gen1_gc_interval: 1

eval_interval: 1000
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}

data:
  pad_direction: right
  # generate_doc_lengths: true
  num_workers: 32
  drop_last: true
  pin_memory: true
  prefetch_factor: 8
  persistent_workers: true
  memmap_dtype: uint32
  timeout: 0
  instance_filter:
    repetition_max_period: 13
    repetition_min_period: 1
    repetition_max_count: 32
```
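(For reference, the Z Loss in the plots above is the softmax auxiliary loss enabled by `softmax_auxiliary_loss: true` and scaled by `auxiliary_loss_multiplier`. A minimal sketch of the usual logsumexp-squared form it corresponds to is below; the fused kernel may differ in implementation details.)

```python
# Minimal sketch of the z-loss (softmax auxiliary loss), assuming the standard
# logsumexp-squared penalty; the fused loss kernel may differ in details.
import torch
import torch.nn.functional as F

def ce_and_z_loss(logits: torch.Tensor, labels: torch.Tensor,
                  z_loss_multiplier: float = 1e-5):
    """logits: (N, vocab); labels: (N,) with -100 marking ignored positions."""
    ce = F.cross_entropy(logits, labels, ignore_index=-100)
    # z = logsumexp over the vocab; penalizing z^2 keeps the softmax
    # normalizer close to 1, which stabilizes training.
    z = torch.logsumexp(logits.float(), dim=-1)
    mask = labels != -100
    z_loss = z_loss_multiplier * (z[mask] ** 2).mean()
    return ce, z_loss
```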
Any idea what could be the issue here?
Versions
absl-py==2.1.0 accelerate==0.18.0 -e git+ssh://[email protected]/abhijangda/OLMo.git@77e47c6d84c018fc33a5eda086056c1402f74381#egg=ai2_olmo ai2-olmo-core==0.1.0 aiofiles==23.2.1 aiohappyeyeballs==2.4.3 aiohttp==3.11.3 aioshutil==1.5 aiosignal==1.3.1 annotated-types==0.7.0 antlr4-python3-runtime==4.9.3 anyio==4.6.2.post1 anykeystore==0.2 apex==0.1 appdirs==1.4.4 asttokens==2.4.1 astunparse==1.6.3 attrs==24.2.0 autocommand==2.2.2 backoff==2.2.1 backports.tarfile==1.2.0 beaker-gantry==1.10.0 beaker-py==1.32.3 beautifulsoup4==4.12.3 bitsandbytes==0.44.1 black==23.12.1 boltons==21.0.0 boto3==1.35.84 botocore==1.35.84 bracex==2.5.post1 Brotli==1.1.0 build==1.2.2.post1 cached_path==1.6.5 cachetools==5.5.0 certifi==2024.8.30 cffi==1.17.1 chardet==5.2.0 charset-normalizer==3.4.0 click==8.1.7 click-help-colors==0.9.4 click-option-group==0.5.6 cmake==3.31.0.1 codeshield==1.0.1 colorama==0.4.6 coloredlogs==15.0.1 contourpy==1.3.1 cryptacular==1.6.2 cryptography==43.0.3 cupy==13.3.0 cxxfilt==0.3.0 cycler==0.12.1 dataclasses-json==0.6.7 datasets==3.2.0 decorator==5.1.1 defusedxml==0.7.1 Deprecated==1.2.15 dill==0.3.6 distro==1.9.0 docker==7.1.0 docker-pycreds==0.4.0 docutils==0.21.2 effdet==0.4.1 einops==0.8.0 emoji==2.14.0 eval_type_backport==0.2.0 evaluate==0.4.3 exceptiongroup==1.2.2 executing==2.1.0 expecttest==0.2.1 face==24.0.0 fastapi==0.115.5 fastrlock==0.8.2 ffmpy==0.4.0 filelock==3.16.1 filetype==1.2.0 fire==0.7.0 flash_attn @ file:///home/aiscuser/ajangda/flash-attention flatbuffers==24.3.25 fonttools==4.55.0 frozenlist==1.5.0 fsspec==2023.9.2 ftfy==6.3.1 gitdb==4.0.11 GitPython==3.1.43 glom==22.1.0 google-api-core==2.23.0 google-auth==2.36.0 google-cloud-core==2.4.1 google-cloud-storage==2.19.0 google-cloud-vision==3.8.1 google-crc32c==1.6.0 google-resumable-media==2.7.2 googleapis-common-protos==1.66.0 gradio==5.6.0 gradio_client==1.4.3 greenlet==3.1.1 grpcio==1.68.0 grpcio-status==1.62.3 h11==0.14.0 httpcore==1.0.7 httpx==0.27.2 huggingface-hub==0.26.5 humanfriendly==10.0 hupper==1.12.1 hypothesis==6.119.2 idna==3.10 importlib_metadata==7.1.0 importlib_resources==6.4.0 inflate64==1.0.0 inflect==7.3.1 iniconfig==2.0.0 iopath==0.1.10 ipython==8.29.0 isort==5.12.0 jaraco.classes==3.4.0 jaraco.context==5.3.0 jaraco.functools==4.0.1 jaraco.text==3.12.1 jedi==0.19.2 jeepney==0.8.0 Jinja2==3.1.4 jmespath==1.0.1 joblib==1.4.2 jsonpatch==1.33 jsonpath-python==1.0.6 jsonpointer==3.0.0 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 keyring==25.5.0 kiwisolver==1.4.7 langchain==0.2.17 langchain-community==0.2.19 langchain-core==0.2.43 langchain-openai==0.1.20 langchain-text-splitters==0.2.4 langdetect==1.0.9 langsmith==0.1.143 layoutparser==0.3.4 lightning-utilities==0.11.9 lintrunner==0.12.5 loralib==0.1.2 lxml==5.3.0 markdown-it-py==3.0.0 MarkupSafe==2.1.5 marshmallow==3.23.1 matplotlib==3.9.2 matplotlib-inline==0.1.7 mdurl==0.1.2 more-itertools==10.3.0 mpi4py @ file:///work/ci_py311/mpi4py_1676858691457/work mpmath==1.3.0 mscclpp @ file:///home/ajangda/mscclpp msgspec==0.18.6 multidict==6.1.0 multiprocess==0.70.14 multivolumefile==0.2.3 mypy==1.3.0 mypy-extensions==1.0.0 necessary==0.4.3 nest-asyncio==1.6.0 netifaces==0.11.0 networkx==3.4.2 nh3==0.2.20 ninja==1.11.1.1 nltk==3.9.1 numpy==1.26.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==9.1.0.70 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 
nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.2 omegaconf==2.3.0 onnx==1.17.0 onnxruntime==1.20.0 openai==1.39.0 opencv-python==4.10.0.84 opentelemetry-api==1.25.0 opentelemetry-exporter-otlp-proto-common==1.25.0 opentelemetry-exporter-otlp-proto-http==1.25.0 opentelemetry-instrumentation==0.46b0 opentelemetry-instrumentation-requests==0.46b0 opentelemetry-proto==1.25.0 opentelemetry-sdk==1.25.0 opentelemetry-semantic-conventions==0.46b0 opentelemetry-util-http==0.46b0 optimum==1.23.3 optree==0.13.1 ordered-set==4.1.0 orjson==3.10.11 packaging==24.2 pandas==2.2.3 parso==0.8.4 PasteDeploy==3.1.0 pathspec==0.12.1 pbkdf2==1.3 pdf2image==1.17.0 pdfminer.six==20231228 pdfplumber==0.11.4 peewee==3.17.8 peft==0.13.2 petname==2.6 pexpect==4.9.0 pi_heif==0.20.0 pikepdf==9.4.2 pillow==11.0.0 pkginfo==1.12.0 plaster==1.1.2 plaster-pastedeploy==1.0.1 platformdirs==4.2.2 pluggy==1.5.0 portalocker==3.0.0 prettytable==3.12.0 prompt_toolkit==3.0.48 propcache==0.2.0 proto-plus==1.25.0 protobuf==4.25.5 psutil==6.1.0 ptyprocess==0.7.0 pure_eval==0.2.3 py7zr==0.22.0 pyarrow==18.0.0 pyasn1==0.6.1 pyasn1_modules==0.4.1 pybcj==1.0.2 pybind11==2.13.6 pybind11_global==2.13.6 pycocotools==2.0.8 pycparser==2.22 pycryptodomex==3.21.0 pydantic==2.9.2 pydantic_core==2.23.4 pydub==0.25.1 pyfastkron @ file:///home/aiscuser/ajangda/OLMo/pyfastkron-1.0.1-py3-none-any.whl#sha256=600f33c84967e12106e7e2b25f583422bf4a1a1f8dc887b5e8df54fa9bba2082 Pygments==2.18.0 pyparsing==3.2.0 pypdf==5.1.0 pypdfium2==4.30.0 pyppmd==1.1.0 pyproject_hooks==1.2.0 pyramid==2.0.2 pyramid-mailer==0.15.1 pytest==8.3.4 pytest-sphinx==0.6.3 python-dateutil==2.8.2 python-iso639==2024.10.22 python-magic==0.4.27 python-multipart==0.0.12 python3-openid==3.2.0 pytorch-triton-rocm==3.1.0 pytz==2024.2 PyYAML==6.0.1 pyzstd==0.16.2 RapidFuzz==3.10.1 readme_renderer==44.0 referencing==0.35.1 regex==2024.11.6 repoze.sendmail==4.4.1 requests==2.32.3 requests-oauthlib==2.0.0 requests-toolbelt==1.0.0 requirements-parser==0.11.0 responses==0.18.0 rfc3986==2.0.0 rich==13.5.3 rouge_score==0.1.2 rpds-py==0.21.0 rsa==4.9 ruamel.yaml==0.17.40 ruamel.yaml.clib==0.2.12 ruff==0.7.4 s3transfer==0.10.4 safehttpx==0.1.1 safetensors==0.4.5 scikit-learn==1.5.2 scipy==1.14.1 SecretStorage==3.3.3 semantic-version==2.10.0 semgrep==1.96.0 sentence-transformers==3.3.1 sentencepiece==0.2.0 sentry-sdk==2.19.2 setproctitle==1.3.4 shellingham==1.5.4 six==1.16.0 smart-open==7.1.0 smashed==0.21.5 smmap==5.0.1 sniffio==1.3.1 sortedcontainers==2.4.0 soupsieve==2.6 SQLAlchemy==2.0.36 stack-data==0.6.3 starlette==0.41.3 sympy==1.13.1 tabulate==0.9.0 tenacity==8.5.0 termcolor==2.5.0 texttable==1.7.0 threadpoolctl==3.5.0 tiktoken==0.8.0 timm==1.0.11 tokenize_rt==6.1.0 tokenizers==0.13.3 tomli==2.0.1 tomlkit==0.12.0 torch==2.5.1+rocm6.1 torchaudio==2.5.1+rocm6.1 torchmetrics==1.6.0 torchvision==0.20.1+rocm6.1 tqdm==4.67.1 traitlets==5.14.3 transaction==5.0 transformers==4.28.1 translationstring==1.4 triton==3.1.0 trouting==0.3.3 twine==6.0.1 typeguard==4.3.0 typer==0.13.1 types-dataclasses==0.6.6 types-setuptools==75.6.0.20241126 typing-inspect==0.9.0 typing_extensions==4.12.2 tzdata==2024.2 unstructured==0.15.8 unstructured-client==0.27.0 unstructured-inference==0.7.36 unstructured.pytesseract==0.3.13 urllib3==2.2.3 uvicorn==0.32.0 velruse==1.1.1 venusian==3.1.1 wandb==0.19.1 wcmatch==8.5.2 wcwidth==0.2.13 WebOb==1.8.9 websockets==12.0 wrapt==1.16.0 WTForms==3.2.1 wtforms-recaptcha==0.3.2 xxhash==3.5.0 yarl==1.17.2 zipp==3.19.2 
zope.deprecation==5.0 zope.interface==7.2 zope.sqlalchemy==3.1
Can you link me to the wandb run? This might be a graphing issue with wandb: it uses different sampling depending on how long the run is, so shorter runs appear noisier than longer ones.
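If you can't share the run, one way to rule out chart downsampling is to pull the full-fidelity history through the wandb public API and compare the spread before and after step 4400 directly. A rough sketch (the run path and metric key are placeholders; substitute whatever key your run logs for the cross-entropy loss):

```python
# Sketch: pull the un-downsampled history from the wandb public API and compare
# the spread of the loss before and after step 4400. The run path and metric
# key below are placeholders; substitute your own.
import statistics
import wandb

api = wandb.Api()
run = api.run("abhijangda-microsoft/llm-kron/RUN_ID")  # placeholder run path

before, after = [], []
for row in run.scan_history(keys=["_step", "train/CrossEntropyLoss"]):
    loss = row.get("train/CrossEntropyLoss")
    if loss is None:
        continue
    (before if row["_step"] < 4400 else after).append(loss)

print("std before step 4400:", statistics.pstdev(before))
print("std after  step 4400:", statistics.pstdev(after))
```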
Hi, thanks again for the inquiry! We’re currently working on closing out old tickets, so we’re closing this out for now, but if you require a follow-up response, please re-open and we will get back to you!