Torch2 updt
Uses https://github.com/mosaicml/llm-foundry/pull/147 as a springboard to update torch.
In an interactive instance, I installed the torch2 requirements and everything works fine.
A 125M model was getting good (the same) MFU from the exact same config on both torch1.13 and torch2.
Note: the torch2 version's `pip list` shows both triton versions:

```
torch                  2.0.1+cu118
triton                 2.0.0
triton-pre-mlir        2.0.0
```

This doesn't seem to matter.
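For reference, a quick way to confirm which builds ended up in the environment (a minimal sketch; it just reads package metadata, equivalent to grepping `pip list`):

```python
# Minimal sketch: report the installed torch / triton distributions.
from importlib.metadata import PackageNotFoundError, version

for dist in ("torch", "triton", "triton-pre-mlir"):
    try:
        print(f"{dist}: {version(dist)}")
    except PackageNotFoundError:
        print(f"{dist}: not installed")
```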
Note: this does not use torch.compile() (but there is no reason it shouldn't)
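If we did want to try it, the change should be localized to where the model is built; a hedged sketch with a stand-in module (not what train.py actually constructs):

```python
import torch
import torch.nn as nn

# Stand-in module only; torch.compile() wraps the model and should drop into
# the existing training loop without other changes (torch2-only API).
model = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 1))
compiled = torch.compile(model)
out = compiled(torch.randn(4, 16))  # compilation happens on the first call
```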
Note: flash-attn is still installed, and xentropy-cuda-lib is also still installed; I'm not setting loss_fn, so MPT defaults to using fused_crossentropy in both settings.
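Spelled out, the default is equivalent to setting the loss explicitly in the model config; the key and value names below are my reading of the MPT defaults, not something this PR changes:

```python
# Assumed config keys: leaving `loss_fn` unset is equivalent to the fused
# setting below, which relies on xentropy-cuda-lib being installed.
model_cfg = {
    'name': 'mpt_causal_lm',           # assumed model name
    'loss_fn': 'fused_crossentropy',   # the default described above
}
```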
Biggest low-probability risk: this old version of triton does not compile / work on H100s... 👀
Risk: triton_pre_mlir has no support and will never be updated.
~~Still need to test at scale / convergence~~ torch2 and torch1.13 produce the same results; see here
~~blocking composer PR: https://github.com/mosaicml/composer/pull/2229 (waiting for new composer image)~~ merged
Issue: checkpointing with Torch2 caused the run to crash:

```
 Eval metrics/eval/LanguagePerplexity: 60.0989
Traceback (most recent call last):
  File "/llm-foundry/scripts/train/train.py", line 254, in <module>
    main(cfg)
  File "/llm-foundry/scripts/train/train.py", line 243, in main
    trainer.fit()
  File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 1766, in fit
    self._train_loop()
  File "/usr/lib/python3/dist-packages/composer/trainer/trainer.py", line 1996, in _train_loop
    self.engine.run_event(Event.BATCH_CHECKPOINT)
  File "/usr/lib/python3/dist-packages/composer/core/engine.py", line 293, in run_event
    self._run_nonlogger_callbacks(event)
  File "/usr/lib/python3/dist-packages/composer/core/engine.py", line 475, in _run_nonlogger_callbacks
    self._run_callbacks(event, callbacks)
  File "/usr/lib/python3/dist-packages/composer/core/engine.py", line 467, in _run_callbacks
    cb.run_event(event, self.state, self.logger)
  File "/usr/lib/python3/dist-packages/composer/core/callback.py", line 96, in run_event
    return event_cb(state, logger)
  File "/usr/lib/python3/dist-packages/composer/callbacks/checkpoint_saver.py", line 346, in batch_checkpoint
    self._save_checkpoint(
  File "/usr/lib/python3/dist-packages/composer/callbacks/checkpoint_saver.py", line 384, in _save_checkpoint
    saved_path = checkpoint.save_checkpoint(
  File "/usr/lib/python3/dist-packages/composer/utils/checkpoint.py", line 518, in save_checkpoint
    'state': state.state_dict(),
  File "/usr/lib/python3/dist-packages/composer/core/state.py", line 802, in state_dict
    fsdp_get_optim_state_dict(self.model, optimizer, state_dict_type=self.fsdp_state_dict_type)
  File "/usr/lib/python3/dist-packages/composer/core/state.py", line 127, in fsdp_get_optim_state_dict
    optim_state_dict = FSDP.optim_state_dict(model, optim)  # type: ignore
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1753, in optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_impl(
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1154, in _optim_state_dict_impl
    return _optim_state_dict(
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1455, in _optim_state_dict
    _gather_orig_param_state(
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1690, in _gather_orig_param_state
    gathered_state = _all_gather_optim_state(fsdp_state, optim_state)
  File "/usr/lib/python3/dist-packages/torch/distributed/fsdp/_optim_utils.py", line 1637, in _all_gather_optim_state
    for name, non_tensor_value in object_state.non_tensors.items():
AttributeError: 'int' object has no attribute 'items'
```
With `parameters['fsdp_config']['use_orig_params'] = False`, checkpointing is not broken and everything runs fine.
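For anyone hitting the same crash, the relevant piece of the run config looks like this (a sketch; the other FSDP key is illustrative and unrelated to the bug):

```python
# Workaround sketch: with use_orig_params=False, FSDP.optim_state_dict no
# longer hits the AttributeError above when the checkpoint saver runs.
fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',  # illustrative default
    'use_orig_params': False,           # True triggers the crash in this image
}
```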
Running from this branch, I still hit the same checkpointing issue (with attn_impl: torch) when using mosaicml/pytorch:2.0.0_cu117-python3.10-ubuntu20.04 as the base image.
If I run in interactive mode and install torch from PyTorch manually, the issue does not occur.
The issue isn't with this branch / PR; it seems to be an issue with the composer image (https://github.com/mosaicml/composer/issues/2231).
cc @eracah @bcui19 @dakinggg
@eracah identified this as an issue with how optimizers were implemented in composer.
Can we make sure that CI (github) tests are run on torch 2 images as part of this change?
+1. Take a look at what we do in Composer
closing in favor of https://github.com/mosaicml/llm-foundry/pull/178
Thanks for this.
Does the mosaicml/pytorch:2.0.0_cu117-python3.10-ubuntu20.04 image already have torch 2.0 and the new triton MLIR?
It seems like there should now be no problem launching MPT and other models with triton flash attention?
@germanjke
I'd use the updated mosaicml/pytorch:2.0.1_cu117-python3.10-ubuntu20.04 image.
The triton installed with torch2+ has issues, so we use a fork of triton for our triton attention implementation (see here and here).
@vchiley
What about training configs?
Do we need to change `attn_config: attn_impl: triton` to triton_mlir?
`attn_config: attn_impl: triton` will run using triton_pre_mlir; as noted, the triton installed with torch2 has issues and would break `attn_impl: triton` if we used it directly.
Once triton2 is fixed and torch2 starts using the fixed version of triton, triton_pre_mlir will be removed and we will revert to using the triton installed with torch.
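In other words, configs stay as they are; a dict-equivalent sketch of the relevant piece:

```python
# No change needed: keep 'triton'. Under torch2 this is served by the
# triton_pre_mlir fork internally, not the triton bundled with torch.
model_cfg = {
    'attn_config': {
        'attn_impl': 'triton',
    },
}
```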