High CrossEntropy and Z Loss variance after loading from checkpoint

Open abhijangda opened this issue 11 months ago • 1 comments

🐛 Describe the bug

I have been experimenting with configs/official-1124/OLMo-7B-stage1.yaml, training on the dataset listed in the YAML file, and I have run into a strange issue: after loading from a checkpoint, the variance of the cross-entropy loss and the Z loss increases dramatically. For example, I first trained up to step 5600 and then restarted training from the checkpoint at step 4400. Here are the loss graphs from wandb:

[image: wandb loss graphs]

You can see clearly that after step 4400 the variance is high.
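
To put a number on the jump rather than reading it off the chart, the per-step loss values could be compared on either side of the resume step. The following is a minimal sketch, not part of OLMo: `variance_around_step` and its `window` parameter are hypothetical names, and it assumes the logged losses are available as a mapping from global step to value.

```python
from statistics import pvariance


def variance_around_step(losses, resume_step, window=200):
    """Return (variance before, variance after) the resume step.

    `losses` maps global step -> logged loss value. Only the `window`
    steps on each side of `resume_step` are used, so the slow downward
    trend of the loss contributes little to either number.
    """
    before = [v for s, v in losses.items()
              if resume_step - window <= s < resume_step]
    after = [v for s, v in losses.items()
             if resume_step <= s < resume_step + window]
    if len(before) < 2 or len(after) < 2:
        raise ValueError("not enough points on one side of resume_step")
    return pvariance(before), pvariance(after)
```

Running this on both the original run and the resumed run over the same step range (here, around step 4400) would show whether the extra variance is in the logged values themselves or only in how they are plotted.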

I have tried this on the following two systems, and both show the same problem.

  • 64 AMD MI300X GPUs across 8 nodes, using ROCm 6.1, PyTorch 2.5.1, and Python 3.11
  • 64 NVIDIA A100 GPUs across 8 nodes, using CUDA 12.4, PyTorch 2.5.1, and Python 3.11

I have also tried changing the number of heads and layers in OLMo-7B-stage1.yaml to 16 and to 32, but both settings show the same issue. I have been using the OLMo Core checkpointer with the following method:

  1. First collect the tensors from all nodes in the model, train, and optim folders of the checkpoint into a single folder accessible to all nodes.
  2. Then set --load_path= to that folder containing all the tensors.
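
The two steps above can be sketched roughly as follows. This is an illustration under assumptions, not the OLMo Core procedure: `merge_checkpoint_shards` is a hypothetical helper, and it assumes each node's checkpoint directory holds flat model, train, and optim subfolders that should keep those names in the merged folder. It refuses to overwrite a file with different contents, since silently mixing shards would produce a checkpoint that loads but is not the one that was saved.

```python
import filecmp
import shutil
from pathlib import Path

SUBFOLDERS = ("model", "train", "optim")


def merge_checkpoint_shards(node_dirs, merged_dir):
    """Copy the files under model/, train/ and optim/ from every node's
    checkpoint directory into one merged directory.

    Returns the number of files copied. Raises FileExistsError if two
    nodes provide a file with the same name but different contents.
    """
    merged = Path(merged_dir)
    copied = 0
    for node_dir in map(Path, node_dirs):
        for sub in SUBFOLDERS:
            src = node_dir / sub
            if not src.is_dir():
                continue  # this node wrote nothing for this part
            dst = merged / sub
            dst.mkdir(parents=True, exist_ok=True)
            for f in sorted(src.iterdir()):
                if not f.is_file():
                    continue
                target = dst / f.name
                if target.exists():
                    # Identical duplicates (e.g. shared metadata) are fine.
                    if filecmp.cmp(f, target, shallow=False):
                        continue
                    raise FileExistsError(f"conflicting copies of {target}")
                shutil.copy2(f, target)
                copied += 1
    return copied
```

The merged directory would then be the value passed to --load_path=.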

Below is the config I used (I removed the dataset URLs):

run_name: OLMo2-7B-stage1
seed: 6198
dry_run: false

model:
  d_model: 4096
  n_heads: 32
  n_layers: 32
  mlp_hidden_size: 22016
  weight_tying: false
  alibi: false
  rope: true
  rope_theta: 500000
  flash_attention: true
  attention_dropout: 0.0
  include_bias: false
  block_type: sequential
  layer_norm_type: rms
  layer_norm_with_affine: true
  layer_norm_eps: 1e-6
  bias_for_layer_norm: false
  attention_layer_norm: true
  attention_layer_norm_with_affine: true
  norm_after: true
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 4096
  vocab_size: 100278
  embedding_size: 100352
  eos_token_id: 100257
  pad_token_id: 100277
  init_device: meta
  init_fn: normal
  init_std: 0.02
  init_cutoff_factor: 3

softmax_auxiliary_loss: true
auxiliary_loss_multiplier: 1e-5
fused_loss: true

compile: null

wandb:
  project: "llm-kron"
  entity: "abhijangda-microsoft"
  log_interval: 1
  group: "7B"

optimizer:
  name: adamw
  learning_rate: 3.0e-4
  weight_decay: 0.1
  eps: 1e-8
  decay_norm_and_bias: true
  decay_embeddings: false
  betas:
  - 0.9
  - 0.95
  metrics_log_interval: 1

scheduler:
  name: cosine_with_warmup
  units: tokens
  t_warmup: 8388608000
  t_max: 5e12
  alpha_f: 0.1
  warmup_min_lr: 0.0

tokenizer:
  identifier: tokenizers/allenai_dolma2.json
  truncate_direction: right

save_overwrite: false

save_interval: 1000
save_interval_ephemeral: 250
save_num_checkpoints_to_keep: -1
sharded_checkpointer: olmo_core

save_interval_unsharded: null
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 1ep
global_train_batch_size: 1024
device_train_microbatch_size: 8

precision: amp_bf16

fsdp:
  wrapping_strategy: by_block_and_size
  precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 1

gen1_gc_interval: 1

eval_interval: 1000
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
data:
  pad_direction: right
  # generate_doc_lengths: true
  num_workers: 32
  drop_last: true
  pin_memory: true
  prefetch_factor: 8
  persistent_workers: true
  memmap_dtype: uint32
  timeout: 0
  instance_filter:
    repetition_max_period: 13
    repetition_min_period: 1
    repetition_max_count: 32
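
For orientation, the step counts implied by this config can be worked out directly. This is a rough sketch that assumes every training sequence is exactly max_sequence_length tokens long and that the scheduler counts tokens as batch size times sequence length; neither assumption is confirmed here. The GPU count is taken from the hardware list above.

```python
# Values copied from the config; num_gpus comes from the hardware list (64 GPUs).
global_train_batch_size = 1024    # sequences per optimizer step
max_sequence_length = 4096        # tokens per sequence
device_train_microbatch_size = 8  # sequences per GPU per forward/backward pass
num_gpus = 64
t_warmup_tokens = 8_388_608_000   # scheduler.t_warmup, with units: tokens
resume_step = 4400                # checkpoint the run was restarted from

# Assumption: every sequence is exactly max_sequence_length tokens long.
tokens_per_step = global_train_batch_size * max_sequence_length
warmup_steps = t_warmup_tokens // tokens_per_step
grad_accum_steps = global_train_batch_size // (device_train_microbatch_size * num_gpus)
tokens_at_resume = resume_step * tokens_per_step

print(tokens_per_step, warmup_steps, grad_accum_steps, tokens_at_resume)
# prints: 4194304 2000 2 18454937600
```

Under those assumptions warmup ends around step 2000, so the checkpoint at step 4400 is already past warmup. Comparing the logged learning rate of the original and the resumed run at that step is one way to check whether the scheduler state was restored.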

Any idea what could be the issue here?

Versions

absl-py==2.1.0 accelerate==0.18.0 -e git+ssh://[email protected]/abhijangda/OLMo.git@77e47c6d84c018fc33a5eda086056c1402f74381#egg=ai2_olmo ai2-olmo-core==0.1.0 aiofiles==23.2.1 aiohappyeyeballs==2.4.3 aiohttp==3.11.3 aioshutil==1.5 aiosignal==1.3.1 annotated-types==0.7.0 antlr4-python3-runtime==4.9.3 anyio==4.6.2.post1 anykeystore==0.2 apex==0.1 appdirs==1.4.4 asttokens==2.4.1 astunparse==1.6.3 attrs==24.2.0 autocommand==2.2.2 backoff==2.2.1 backports.tarfile==1.2.0 beaker-gantry==1.10.0 beaker-py==1.32.3 beautifulsoup4==4.12.3 bitsandbytes==0.44.1 black==23.12.1 boltons==21.0.0 boto3==1.35.84 botocore==1.35.84 bracex==2.5.post1 Brotli==1.1.0 build==1.2.2.post1 cached_path==1.6.5 cachetools==5.5.0 certifi==2024.8.30 cffi==1.17.1 chardet==5.2.0 charset-normalizer==3.4.0 click==8.1.7 click-help-colors==0.9.4 click-option-group==0.5.6 cmake==3.31.0.1 codeshield==1.0.1 colorama==0.4.6 coloredlogs==15.0.1 contourpy==1.3.1 cryptacular==1.6.2 cryptography==43.0.3 cupy==13.3.0 cxxfilt==0.3.0 cycler==0.12.1 dataclasses-json==0.6.7 datasets==3.2.0 decorator==5.1.1 defusedxml==0.7.1 Deprecated==1.2.15 dill==0.3.6 distro==1.9.0 docker==7.1.0 docker-pycreds==0.4.0 docutils==0.21.2 effdet==0.4.1 einops==0.8.0 emoji==2.14.0 eval_type_backport==0.2.0 evaluate==0.4.3 exceptiongroup==1.2.2 executing==2.1.0 expecttest==0.2.1 face==24.0.0 fastapi==0.115.5 fastrlock==0.8.2 ffmpy==0.4.0 filelock==3.16.1 filetype==1.2.0 fire==0.7.0 flash_attn @ file:///home/aiscuser/ajangda/flash-attention flatbuffers==24.3.25 fonttools==4.55.0 frozenlist==1.5.0 fsspec==2023.9.2 ftfy==6.3.1 gitdb==4.0.11 GitPython==3.1.43 glom==22.1.0 google-api-core==2.23.0 google-auth==2.36.0 google-cloud-core==2.4.1 google-cloud-storage==2.19.0 google-cloud-vision==3.8.1 google-crc32c==1.6.0 google-resumable-media==2.7.2 googleapis-common-protos==1.66.0 gradio==5.6.0 gradio_client==1.4.3 greenlet==3.1.1 grpcio==1.68.0 grpcio-status==1.62.3 h11==0.14.0 httpcore==1.0.7 httpx==0.27.2 huggingface-hub==0.26.5 
humanfriendly==10.0 hupper==1.12.1 hypothesis==6.119.2 idna==3.10 importlib_metadata==7.1.0 importlib_resources==6.4.0 inflate64==1.0.0 inflect==7.3.1 iniconfig==2.0.0 iopath==0.1.10 ipython==8.29.0 isort==5.12.0 jaraco.classes==3.4.0 jaraco.context==5.3.0 jaraco.functools==4.0.1 jaraco.text==3.12.1 jedi==0.19.2 jeepney==0.8.0 Jinja2==3.1.4 jmespath==1.0.1 joblib==1.4.2 jsonpatch==1.33 jsonpath-python==1.0.6 jsonpointer==3.0.0 jsonschema==4.23.0 jsonschema-specifications==2024.10.1 keyring==25.5.0 kiwisolver==1.4.7 langchain==0.2.17 langchain-community==0.2.19 langchain-core==0.2.43 langchain-openai==0.1.20 langchain-text-splitters==0.2.4 langdetect==1.0.9 langsmith==0.1.143 layoutparser==0.3.4 lightning-utilities==0.11.9 lintrunner==0.12.5 loralib==0.1.2 lxml==5.3.0 markdown-it-py==3.0.0 MarkupSafe==2.1.5 marshmallow==3.23.1 matplotlib==3.9.2 matplotlib-inline==0.1.7 mdurl==0.1.2 more-itertools==10.3.0 mpi4py @ file:///work/ci_py311/mpi4py_1676858691457/work mpmath==1.3.0 mscclpp @ file:///home/ajangda/mscclpp msgspec==0.18.6 multidict==6.1.0 multiprocess==0.70.14 multivolumefile==0.2.3 mypy==1.3.0 mypy-extensions==1.0.0 necessary==0.4.3 nest-asyncio==1.6.0 netifaces==0.11.0 networkx==3.4.2 nh3==0.2.20 ninja==1.11.1.1 nltk==3.9.1 numpy==1.26.4 nvidia-cublas-cu12==12.1.3.1 nvidia-cuda-cupti-cu12==12.1.105 nvidia-cuda-nvrtc-cu12==12.1.105 nvidia-cuda-runtime-cu12==12.1.105 nvidia-cudnn-cu12==9.1.0.70 nvidia-cufft-cu12==11.0.2.54 nvidia-curand-cu12==10.3.2.106 nvidia-cusolver-cu12==11.4.5.107 nvidia-cusparse-cu12==12.1.0.106 nvidia-nccl-cu12==2.21.5 nvidia-nvjitlink-cu12==12.4.127 nvidia-nvtx-cu12==12.1.105 oauthlib==3.2.2 omegaconf==2.3.0 onnx==1.17.0 onnxruntime==1.20.0 openai==1.39.0 opencv-python==4.10.0.84 opentelemetry-api==1.25.0 opentelemetry-exporter-otlp-proto-common==1.25.0 opentelemetry-exporter-otlp-proto-http==1.25.0 opentelemetry-instrumentation==0.46b0 opentelemetry-instrumentation-requests==0.46b0 opentelemetry-proto==1.25.0 opentelemetry-sdk==1.25.0 
opentelemetry-semantic-conventions==0.46b0 opentelemetry-util-http==0.46b0 optimum==1.23.3 optree==0.13.1 ordered-set==4.1.0 orjson==3.10.11 packaging==24.2 pandas==2.2.3 parso==0.8.4 PasteDeploy==3.1.0 pathspec==0.12.1 pbkdf2==1.3 pdf2image==1.17.0 pdfminer.six==20231228 pdfplumber==0.11.4 peewee==3.17.8 peft==0.13.2 petname==2.6 pexpect==4.9.0 pi_heif==0.20.0 pikepdf==9.4.2 pillow==11.0.0 pkginfo==1.12.0 plaster==1.1.2 plaster-pastedeploy==1.0.1 platformdirs==4.2.2 pluggy==1.5.0 portalocker==3.0.0 prettytable==3.12.0 prompt_toolkit==3.0.48 propcache==0.2.0 proto-plus==1.25.0 protobuf==4.25.5 psutil==6.1.0 ptyprocess==0.7.0 pure_eval==0.2.3 py7zr==0.22.0 pyarrow==18.0.0 pyasn1==0.6.1 pyasn1_modules==0.4.1 pybcj==1.0.2 pybind11==2.13.6 pybind11_global==2.13.6 pycocotools==2.0.8 pycparser==2.22 pycryptodomex==3.21.0 pydantic==2.9.2 pydantic_core==2.23.4 pydub==0.25.1 pyfastkron @ file:///home/aiscuser/ajangda/OLMo/pyfastkron-1.0.1-py3-none-any.whl#sha256=600f33c84967e12106e7e2b25f583422bf4a1a1f8dc887b5e8df54fa9bba2082 Pygments==2.18.0 pyparsing==3.2.0 pypdf==5.1.0 pypdfium2==4.30.0 pyppmd==1.1.0 pyproject_hooks==1.2.0 pyramid==2.0.2 pyramid-mailer==0.15.1 pytest==8.3.4 pytest-sphinx==0.6.3 python-dateutil==2.8.2 python-iso639==2024.10.22 python-magic==0.4.27 python-multipart==0.0.12 python3-openid==3.2.0 pytorch-triton-rocm==3.1.0 pytz==2024.2 PyYAML==6.0.1 pyzstd==0.16.2 RapidFuzz==3.10.1 readme_renderer==44.0 referencing==0.35.1 regex==2024.11.6 repoze.sendmail==4.4.1 requests==2.32.3 requests-oauthlib==2.0.0 requests-toolbelt==1.0.0 requirements-parser==0.11.0 responses==0.18.0 rfc3986==2.0.0 rich==13.5.3 rouge_score==0.1.2 rpds-py==0.21.0 rsa==4.9 ruamel.yaml==0.17.40 ruamel.yaml.clib==0.2.12 ruff==0.7.4 s3transfer==0.10.4 safehttpx==0.1.1 safetensors==0.4.5 scikit-learn==1.5.2 scipy==1.14.1 SecretStorage==3.3.3 semantic-version==2.10.0 semgrep==1.96.0 sentence-transformers==3.3.1 sentencepiece==0.2.0 sentry-sdk==2.19.2 setproctitle==1.3.4 shellingham==1.5.4 
six==1.16.0 smart-open==7.1.0 smashed==0.21.5 smmap==5.0.1 sniffio==1.3.1 sortedcontainers==2.4.0 soupsieve==2.6 SQLAlchemy==2.0.36 stack-data==0.6.3 starlette==0.41.3 sympy==1.13.1 tabulate==0.9.0 tenacity==8.5.0 termcolor==2.5.0 texttable==1.7.0 threadpoolctl==3.5.0 tiktoken==0.8.0 timm==1.0.11 tokenize_rt==6.1.0 tokenizers==0.13.3 tomli==2.0.1 tomlkit==0.12.0 torch==2.5.1+rocm6.1 torchaudio==2.5.1+rocm6.1 torchmetrics==1.6.0 torchvision==0.20.1+rocm6.1 tqdm==4.67.1 traitlets==5.14.3 transaction==5.0 transformers==4.28.1 translationstring==1.4 triton==3.1.0 trouting==0.3.3 twine==6.0.1 typeguard==4.3.0 typer==0.13.1 types-dataclasses==0.6.6 types-setuptools==75.6.0.20241126 typing-inspect==0.9.0 typing_extensions==4.12.2 tzdata==2024.2 unstructured==0.15.8 unstructured-client==0.27.0 unstructured-inference==0.7.36 unstructured.pytesseract==0.3.13 urllib3==2.2.3 uvicorn==0.32.0 velruse==1.1.1 venusian==3.1.1 wandb==0.19.1 wcmatch==8.5.2 wcwidth==0.2.13 WebOb==1.8.9 websockets==12.0 wrapt==1.16.0 WTForms==3.2.1 wtforms-recaptcha==0.3.2 xxhash==3.5.0 yarl==1.17.2 zipp==3.19.2 zope.deprecation==5.0 zope.interface==7.2 zope.sqlalchemy==3.1

abhijangda • Jan 06 '25 18:01

Can you link me to the wandb? This might be a graphing issue with wandb. Wandb will use different sampling depending on how long the run is. Shorter runs appear noisier than longer ones.
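
The effect described here can be illustrated with a toy example. The sketch below assumes, purely for illustration, that a chart reduces a run to a fixed number of points by averaging equal-sized buckets of steps; this thread does not state how wandb actually samples, and `bucket_means` is a hypothetical helper, not a wandb function.

```python
from statistics import pstdev


def bucket_means(values, max_points):
    """Downsample `values` to at most `max_points` points by averaging
    consecutive, equal-sized buckets (the last bucket may be shorter)."""
    if max_points <= 0:
        raise ValueError("max_points must be positive")
    size = -(-len(values) // max_points)  # ceiling division
    return [
        sum(values[i:i + size]) / len(values[i:i + size])
        for i in range(0, len(values), size)
    ]


# The same per-step noise, plotted with a budget of 1000 points.
short_run = [1.0 if s % 2 else -1.0 for s in range(1000)]  # 1 step per point
long_run = [1.0 if s % 2 else -1.0 for s in range(6000)]   # 6 steps per point

print(pstdev(bucket_means(short_run, 1000)))  # 1.0: noise fully visible
print(pstdev(bucket_means(long_run, 1000)))   # 0.0: noise averaged away
```

Under that assumption, a run that covers fewer steps is drawn at a finer resolution and looks noisier even when the per-step values have the same spread, so the two runs are best compared on unsampled per-step values over the same step range.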

dirkgr • Feb 06 '25 00:02

Hi, thanks again for the inquiry! We’re currently working on closing out old tickets, so we’re closing this out for now, but if you require a follow-up response, please re-open and we will get back to you!

baileykuehl • Jul 01 '25 17:07