Torch2 update

vchiley opened this issue 2 years ago

Uses https://github.com/mosaicml/llm-foundry/pull/147 as a springboard to update torch.

In an interactive instance, I installed the torch2 requirements and everything works fine.

The 125M model was getting good (the same) MFU from the exact same config in both torch1.13 and torch2.

Note: in the torch2 version, pip list shows both triton versions:

torch                  2.0.1+cu118

triton                 2.0.0
triton-pre-mlir        2.0.0

This doesn't seem to matter.
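
The two installs don't clash because they import under different top-level module names. A quick, hedged sanity check (the fork's module name is assumed from its packaging and should be verified):

```python
# Illustrative sanity check: both triton installs can coexist because they
# live under different top-level module names.
import triton            # the triton shipped alongside torch 2.x
import triton_pre_mlir   # the pinned pre-MLIR fork used by the triton attn impl (name assumed from its packaging)
```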

Note: this does not use torch.compile() (but there is no reason it shouldn't)
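
For reference, a minimal sketch of what opting in would look like; this is not part of the PR, and the placeholder module below just stands in for the actual MPT model:

```python
import torch
import torch.nn as nn

# Placeholder module standing in for the MPT model; illustration only.
model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))

# torch.compile (new in torch 2.x) wraps the module in an optimized callable;
# the first forward pass triggers graph capture and compilation.
compiled_model = torch.compile(model)

x = torch.randn(2, 8)
y = compiled_model(x)
```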

Note: flash-attn is still installed. xentropy-cuda-lib is also still installed; I'm not setting loss_fn so mpt defaults to using fused_crossentropy for both settings.
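
If you'd rather pin the loss explicitly than rely on the default, a hedged sketch of the relevant slice of the model config (the key name and value strings are recalled from the MPT config and should be checked against the model code):

```python
# Hedged sketch of the model-config slice that selects the loss function.
# Key name and values should be verified against the MPT model code.
model_cfg = {
    'name': 'mpt_causal_lm',            # illustrative
    'loss_fn': 'fused_crossentropy',    # default; requires xentropy-cuda-lib
    # 'loss_fn': 'torch_crossentropy',  # pure-torch fallback if the fused kernel isn't installed
}
```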

Biggest (low-probability) risk: this old version of triton does not compile / work on H100s... 👀 Risk: triton_pre_mlir is unsupported and will never be updated.

~~Still need to test at scale / convergence~~ torch2 and torch1.13 produce the same results; see here.

vchiley, May 16 '23 22:05

~~blocking composer pr: https://github.com/mosaicml/composer/pull/2229 (waiting for new composer img)~~ merged

vchiley, May 17 '23 01:05

Issue: using Torch2, checkpointing caused the run to crash:

 Eval metrics/eval/LanguagePerplexity: 60.0989
Traceback (most recent call last):
  File "/llm-foundry/scripts/train/train.py", line 254, in <module>
    main(cfg)
  File "/llm-foundry/scripts/train/train.py", line 243, in main
    trainer.fit()
  File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 1766, in fit
    self._train_loop()
  File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 1996, in _train_loop
    self.engine.run_event(Event.BATCH_CHECKPOINT)
  File "/usr/lib/python3/dist-packages/composer/core/engine.py", line 293, in run_event
    self._run_nonlogger_callbacks(event)
  File "/usr/lib/python3/dist-packages/composer/core/engine.py", line 475, in _run_nonlogger_callbacks
    self._run_callbacks(event, callbacks)
  File "/usr/lib/python3/dist-packages/composer/core/engine.py", line 467, in _run_callbacks
    cb.run_event(event, self.state, self.logger)
  File "/usr/lib/python3/dist-packages/composer/core/callback.py", line 96, in run_event
    return event_cb(state, logger)
  File "/usr/lib/python3/dist-packages/composer/callbacks/checkpoint_saver.py", line 346, in batch_checkpoint
    self._save_checkpoint(
  File "/usr/lib/python3/dist-packages/composer/callbacks/checkpoint_saver.py", line 384, in _save_checkpoint
    saved_path = checkpoint.save_checkpoint(
  File "/usr/lib/python3/dist-packages/composer/utils/checkpoint.py", line 518, in save_checkpoint
    'state': state.state_dict(),
  File "/usr/lib/python3/dist-packages/composer/core/state.py", line 802, in state_dict
    fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type)
  File "/usr/lib/python3/dist-packages/composer/core/state.py", line 127, in fsdp_get_optim_state_dict
    optim_state_dict = FSDP.optim_state_dict(model, optim)  # type: ignore
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1753, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1154, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1455, in _optim_state_dict
    _gather_orig_param_state(
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1690, in _gather_orig_param_state
    gathered_state = _all_gather_optim_state(fsdp_state, optim_state)
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1637, in _all_gather_optim_state
    for name, non_tensor_value in object_state.non_tensors.items():
AttributeError: 'int' object has no attribute 'items'

Everything was fine without checkpointing.
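
For readers of the traceback, a small illustration (not the real FSDP code) of the failure pattern: the all-gather loop assumes object_state.non_tensors is a dict, but here it arrives as a plain int, so calling .items() on it raises exactly this AttributeError.

```python
# Minimal illustration of the failure mode, not the actual FSDP implementation:
# the gather loop expects a dict of non-tensor optimizer state but gets an int.
non_tensors = 0  # stand-in for what object_state.non_tensors was in this run

try:
    for name, non_tensor_value in non_tensors.items():
        pass
except AttributeError as e:
    print(e)  # 'int' object has no attribute 'items'
```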

vchiley, May 17 '23 05:05

With parameters['fsdp_config']['use_orig_params'] = False, checkpointing is not broken and everything runs fine.
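
In config-dict form, the workaround slots in roughly like this; only use_orig_params comes from this thread, the other keys are illustrative and may differ from the actual YAML:

```python
# Hedged sketch of the workaround inside a training-config dict; only
# use_orig_params is taken from this thread, the other keys are illustrative.
parameters = {
    'fsdp_config': {
        'sharding_strategy': 'FULL_SHARD',  # illustrative
        'mixed_precision': 'DEFAULT',       # illustrative
        'use_orig_params': False,           # avoids the optim_state_dict crash above
    },
}
```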

vchiley, May 17 '23 23:05

Running from this branch, I still hit the same checkpointing issue (with attn_impl: torch) when using mosaicml/pytorch:2.0.0_cu117-python3.10-ubuntu20.04 as the base image. If I run in interactive mode and install torch manually from PyTorch, the issue does not exist.

The issue isn't with this branch / PR; it seems to be an issue with the composer image (https://github.com/mosaicml/composer/issues/2231).

cc @eracah @bcui19 @dakinggg

vchiley, May 18 '23 16:05

@eracah identified this as an issue with how optimizers were implemented in composer.

vchiley, May 19 '23 15:05

Can we make sure that CI (github) tests are run on torch 2 images as part of this change?

sashaDoubov, May 19 '23 17:05

> Can we make sure that CI (github) tests are run on torch 2 images as part of this change?

+1. Take a look at what we do in Composer

mvpatel2000, May 19 '23 17:05

closing in favor of https://github.com/mosaicml/llm-foundry/pull/178

vchiley, May 19 '23 22:05

Thanks for this. Does the mosaicml/pytorch:2.0.0_cu117-python3.10-ubuntu20.04 image already have torch 2.0 and the new triton MLIR? It seems like there should now be no problems launching MPT and other models with triton flash attention?

germanjke, May 19 '23 22:05

@germanjke I'd use the updated mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04 image. The triton installed with torch2+ has issues; we therefore use a fork of triton for the triton impl (see here and here).

vchiley, May 19 '23 23:05

@vchiley What about the training configs? Do we need to change attn_config: attn_impl: triton to triton_mlir?

germanjke, May 19 '23 23:05

attn_config: attn_impl: triton will run using triton_pre_mlir; as noted, the triton installed with torch2 has issues and would break attn_config: attn_impl: triton.

Once triton2 is fixed and torch2 starts using the fixed version of triton, triton_pre_mlir will be removed and we will revert to using the triton installed with torch.
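
So existing configs keep working unchanged; roughly, in dict form (other keys omitted, structure mirrors the attn_config shown above):

```python
# Hedged sketch: the attention setting stays exactly as before. With this branch,
# 'triton' is backed by the triton_pre_mlir fork under torch2; no new value
# (e.g. a hypothetical 'triton_mlir') is needed.
model_cfg = {
    'attn_config': {
        'attn_impl': 'triton',  # unchanged; served by triton_pre_mlir for now
    },
}
```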

vchiley, May 20 '23 00:05