
How to load an OPT-IML checkpoint and use it for inference?


Thanks for your work on OPT-IML, but I am confused about how to load the OPT-IML checkpoints for inference.

Code

Following the instructions in the Inference API docs, I run the script

python -m metaseq.scripts.reshard_fsdp \                                                                                                               
        --input-glob-pattern "/OPT-IML-30B/checkpoint_1_6000-model_part-*.pt" \                                        
        --output-shard-name "/OPT-IML-30B/reshard-checkpoint_1_6000-model_part-*.pt" \                                                      
        --num-output-shards 1 --skip-optimizer-state True --unflatten-weights True 

And I get the following error:

2023-01-07 13:05:54,294 | metaseq.scripts.reshard_fsdp | Loading all sharded checkpoints to CPU
2023-01-07 13:08:02,081 | metaseq.scripts.reshard_fsdp | Resharding state dicts into 1 shard(s)
Traceback (most recent call last):
  File "/opt/conda/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 242, in <module>
    fire.Fire(reshard_fsdp_checkpoints)
  File "/opt/conda/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/opt/conda/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/opt/conda/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 53, in reshard_fsdp_checkpoints
    resharded_state_dicts = reshard_fsdp_state_dicts(
  File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in reshard_fsdp_state_dicts
    shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
  File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in <listcomp>
    shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
KeyError: 'shard_metadata'

It seems that the checkpoint is missing the key 'shard_metadata'. Could you provide any insights on fixing that?
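
For reference, this is how to check what a shard actually contains (the path is illustrative):

import torch

# Load one model-parallel part onto CPU and list its top-level keys;
# reshard_fsdp.py expects to find 'shard_metadata' alongside 'model'.
shard = torch.load(
    "/OPT-IML-30B/checkpoint_1_6000-model_part-0.pt",  # illustrative path
    map_location="cpu",
)
print(list(shard.keys()))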

linmou avatar Jan 07 '23 04:01 linmou

When this came up for OPT, I believe @stephenroller addressed the above via https://github.com/facebookresearch/metaseq/pull/60/files

cc @tangbinh to check difference with new reshard script

suchenzang avatar Jan 14 '23 07:01 suchenzang

Running the convert_to_singleton.py script mentioned above directly on the IML checkpoints hits a different error:

['--model-parallel-size', '2', '--distributed-world-size', '2', '--ddp-backend', 'pytorch_ddp', '--task', 'language_modeling', '--bpe-merges', '../OPT-IML-30B/gpt2-merges.txt', '--merges-filename', '../OPT-IML-30B/gpt2-merges.txt', '--bpe-vocab', '../OPT-IML-30B/gpt2-vocab.json', '--vocab-filename', '../OPT-IML-30B/gpt2-vocab.json', '--bpe', 'hf_byte_bpe', '--path', '../OPT-IML-30B/reshard.pt', '--checkpoint-shard-count', '1', '--use-sharded-state', '../OPT-IML-30B']
2023-01-16 05:49:19 | INFO | metaseq.distributed.utils | initialized host della-l09g7 as rank 1
2023-01-16 05:49:19 | INFO | metaseq.distributed.utils | initialized host della-l09g7 as rank 0
> initializing tensor model parallel with size 2
> initializing pipeline model parallel with size 1
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 2719 and data parallel seed: 1
2023-01-16 05:50:29 | INFO | metaseq.checkpoint_utils | Done reading from disk
Traceback (most recent call last):
  File "/projects/GRP/USER/conda_envs/metaseq/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/projects/GRP/USER/conda_envs/metaseq/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/projects/GRP/USER/LLM/metaseq/metaseq/scripts/convert_to_singleton.py", line 174, in <module>
    main()
  File "/projects/GRP/USER/LLM/metaseq/metaseq/scripts/convert_to_singleton.py", line 170, in main
    distributed_utils.call_main(cfg, worker_main)
  File "/projects/GRP/USER/LLM/metaseq/metaseq/distributed/utils.py", line 287, in call_main
    return _spawn_helper(main, cfg, kwargs)
  File "/projects/GRP/USER/LLM/metaseq/metaseq/distributed/utils.py", line 265, in _spawn_helper
    retval = distributed_main(-1, main, cfg, kwargs)
  File "/projects/GRP/USER/LLM/metaseq/metaseq/distributed/utils.py", line 227, in distributed_main
    retval = main(cfg, **kwargs)
  File "/projects/GRP/USER/LLM/metaseq/metaseq/scripts/convert_to_singleton.py", line 119, in worker_main
    models, _model_args, _task = checkpoint_utils.load_model_ensemble_and_task(
  File "/projects/GRP/USER/LLM/metaseq/metaseq/checkpoint_utils.py", line 526, in load_model_ensemble_and_task
    model.load_state_dict(state["model"], strict=strict)
  File "/projects/GRP/USER/LLM/metaseq/metaseq/distributed/fully_sharded_data_parallel.py", line 76, in load_state_dict
    return super().load_local_state_dict(state_dict, strict=strict)
  File "/projects/GRP/USER/LLM/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1032, in load_local_state_dict
    output = self._load_state_dict(state_dict, strict)
  File "/projects/GRP/USER/LLM/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1008, in _load_state_dict
    return self.module.load_state_dict(state_dict, strict)
  File "/projects/GRP/USER/LLM/fairscale/fairscale/nn/misc/flatten_params_wrapper.py", line 478, in load_state_dict
    return super().load_state_dict(state_dict, strict)
  File "/home/USER/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for FlattenParamsWrapper:
        Missing key(s) in state_dict: "_fpw_module.decoder.layers.0._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.1._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.2._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.3._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.4._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.5._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.6._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.7._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.8._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.9._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.10._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.11._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.12._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.13._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.14._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.15._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.16._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.17._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.18._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.19._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.20._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.21._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.22._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.23._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.24._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.25._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.26._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.27._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.28._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.29._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.30._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.31._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.32._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.33._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.34._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.35._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.36._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.37._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.38._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.39._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.40._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.41._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.42._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.43._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.44._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.45._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.46._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.47._fsdp_wrapped_module.flat_param_0".
        Unexpected key(s) in state_dict: "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.out_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.out_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc1.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc1.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc2.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc2.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.out_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.out_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc1.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc1.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc2.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc2.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped

I circumvented this by setting strict=False on line 124 of convert_to_singleton.py, but then trying to torch.load the resulting output fails with a different error:

[ins] In [1]: import torch

[ins] In [2]: torch.load('restored.pt', torch.device('cpu'))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 1
----> 1 torch.load('restored.pt', torch.device('cpu'))

File ~/.local/lib/python3.9/site-packages/torch/serialization.py:777, in load(f, map_location, pickle_module, weights_only, **pickle_load_args)
    772 if _is_zipfile(opened_file):
    773     # The zipfile reader is going to advance the current file position.
    774     # If we want to actually tail call to torch.jit.load, we need to
    775     # reset back to the original position.
    776     orig_position = opened_file.tell()
--> 777     with _open_zipfile_reader(opened_file) as opened_zipfile:
    778         if _is_torchscript_zip(opened_zipfile):
    779             warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
    780                           " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
    781                           " silence this warning)", UserWarning)

File ~/.local/lib/python3.9/site-packages/torch/serialization.py:282, in _open_zipfile_reader.__init__(self, name_or_buffer)
    281 def __init__(self, name_or_buffer) -> None:
--> 282     super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))

RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory

Any thoughts? Happy to be an extra pair of eyes/hands on any work converting OPT-IML.
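
For reference, torch saves checkpoints as zip archives by default, so a quick way to tell whether the written restored.pt is even a readable archive (path is illustrative):

import zipfile

# "failed finding central directory" usually means the file is truncated or
# was never written out as a valid zip archive in the first place.
print(zipfile.is_zipfile("restored.pt"))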

muhark avatar Jan 16 '23 15:01 muhark

@suchenzang @stephenroller Could you elaborate on how convert_to_singleton.py helps in this case? It appears to merge reshard*.pt files, but the OPT-IML checkpoint does not include any reshard*.pt. We are encountering errors when converting checkpoint_1_4000.pt-model_part-*.pt to reshard*.pt due to the missing 'shard_metadata'.

xiangjjj avatar Jan 23 '23 03:01 xiangjjj

@linmou I think the released OPT-IML checkpoints have already been consolidated, so there's no need to run the script reshard_fsdp.py to merge FSDP shards (that's why each model parallel part *model_part-*.pt doesn't have any shard metadata). You can load these model parallel parts using interactive scripts, for example, following these instructions.
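
A quick way to confirm this (path is illustrative):

import torch

# Each *model_part-*.pt should contain unflattened per-parameter tensors under
# 'model' (rather than FSDP flat_param_0 buffers) and no 'shard_metadata',
# which is why reshard_fsdp.py isn't needed here.
part = torch.load("/OPT-IML-30B/checkpoint_1_6000-model_part-0.pt", map_location="cpu")
print("shard_metadata" in part)        # expected: False
print(list(part["model"].keys())[:5])  # first few parameter names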

tangbinh avatar Feb 02 '23 22:02 tangbinh

@muhark @xiangjjj I don't think the released OPT-IML checkpoints work with convert_to_singleton.py, which is supposed to merge all model parallel parts (i.e. *model_part-*.pt files) into a singleton and expects its inputs to contain flattened FSDP weights.

You don't need to merge the model parallel parts in order to load the checkpoints, though. For example, you can make some config changes following these instructions (e.g. setting MODEL_FILE to /OPT-IML-30B/checkpoint_1_6000.pt) and load these checkpoints as is via the referenced interactive scripts.
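
A minimal sketch of that kind of change, assuming the interactive script reads its settings from metaseq/service/constants.py; apart from MODEL_FILE (mentioned above), the constant names and values here are assumptions to adapt from the linked instructions:

# metaseq/service/constants.py (sketch; names other than MODEL_FILE are assumptions)
import os

MODEL_SHARED_FOLDER = "/OPT-IML-30B"  # directory holding the *model_part-*.pt files
MODEL_FILE = os.path.join(MODEL_SHARED_FOLDER, "checkpoint_1_6000.pt")
MODEL_PARALLEL = 2                    # assumption: one entry per model parallel part
TOTAL_WORLD_SIZE = 2                  # assumption: one GPU per part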

That said, if you really want to merge the model parallel parts, please try the new script reshard_mp.py.

tangbinh avatar Feb 02 '23 22:02 tangbinh

@tangbinh Thanks for the explanation. The OPT-IML 175B checkpoints come with 16 model parallel parts; how can we run them on 8 GPUs (MP8), as with the original OPT 175B?

frankxu2004 avatar Feb 05 '23 05:02 frankxu2004

@frankxu2004 If you want to convert from MP16 to MP8, please try the reshard_mp.py script (with --num-output-parts 8).
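
Something along these lines should work; only --num-output-parts is confirmed above, so the remaining flag names and paths are assumptions — check python -m metaseq.scripts.reshard_mp --help for the exact interface:

# Flag names other than --num-output-parts and all paths are illustrative.
python -m metaseq.scripts.reshard_mp \
        --input-glob-pattern "/OPT-IML-175B/checkpoint-model_part-*.pt" \
        --output-shard-name "/OPT-IML-175B/mp8/checkpoint-model_part-*.pt" \
        --num-output-parts 8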

tangbinh avatar Feb 06 '23 18:02 tangbinh

It's been a while without any update, so I'm closing the issue now. Please let us know if you need further help.

tangbinh avatar Apr 05 '23 18:04 tangbinh