How to load an OPT-IML checkpoint and use it for inference?
Thanks for your work on OPT-IML, but I am confused about how to load the OPT-IML checkpoints for inference.
Code
Following the instructions in the Inference API docs, I run the script:
python -m metaseq.scripts.reshard_fsdp \
--input-glob-pattern "/OPT-IML-30B/checkpoint_1_6000-model_part-*.pt" \
--output-shard-name "/OPT-IML-30B/reshard-checkpoint_1_6000-model_part-*.pt" \
--num-output-shards 1 --skip-optimizer-state True --unflatten-weights True
and get the following error:
2023-01-07 13:05:54,294 | metaseq.scripts.reshard_fsdp | Loading all sharded checkpoints to CPU
2023-01-07 13:08:02,081 | metaseq.scripts.reshard_fsdp | Resharding state dicts into 1 shard(s)
Traceback (most recent call last):
File "/opt/conda/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 242, in <module>
fire.Fire(reshard_fsdp_checkpoints)
File "/opt/conda/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/opt/conda/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/opt/conda/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 53, in reshard_fsdp_checkpoints
resharded_state_dicts = reshard_fsdp_state_dicts(
File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in reshard_fsdp_state_dicts
shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in <listcomp>
shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
KeyError: 'shard_metadata'
It seems the checkpoint is missing an important key, 'shard_metadata'. Could you provide any insights on how to fix this?
When this came up for OPT, I believe @stephenroller addressed the above via https://github.com/facebookresearch/metaseq/pull/60/files
cc @tangbinh to check difference with new reshard script
Using the convert_to_singleton.py script above runs into a different error when used directly on the IML checkpoints:
['--model-parallel-size', '2', '--distributed-world-size', '2', '--ddp-backend', 'pytorch_ddp', '--task', 'language_modeling', '--bpe-merges', '../OPT-IML-30B/gpt2-merges.txt', '--merges-filename', '../OPT-IML-30B/gpt2-merges.txt', '--bpe-vocab', '../OPT-IML-30B/gpt2-vocab.json', '--vocab-filename', '../OPT-IML-30B/gpt2-vocab.json', '--bpe', 'hf_byte_bpe', '--path', '../OPT-IML-30B/reshard.pt', '--checkpoint-shard-count', '1', '--use-sharded-state', '../OPT-IML-30B']
2023-01-16 05:49:19 | INFO | metaseq.distributed.utils | initialized host della-l09g7 as rank 1
2023-01-16 05:49:19 | INFO | metaseq.distributed.utils | initialized host della-l09g7 as rank 0
> initializing tensor model parallel with size 2
> initializing pipeline model parallel with size 1
> initializing model parallel cuda seeds on global rank 0, model parallel rank 0, and data parallel rank 0 with model parallel seed: 2719 and data parallel seed: 1
2023-01-16 05:50:29 | INFO | metaseq.checkpoint_utils | Done reading from disk
Traceback (most recent call last):
File "/projects/GRP/USER/conda_envs/metaseq/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/projects/GRP/USER/conda_envs/metaseq/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/projects/GRP/USER/LLM/metaseq/metaseq/scripts/convert_to_singleton.py", line 174, in <module>
main()
File "/projects/GRP/USER/LLM/metaseq/metaseq/scripts/convert_to_singleton.py", line 170, in main
distributed_utils.call_main(cfg, worker_main)
File "/projects/GRP/USER/LLM/metaseq/metaseq/distributed/utils.py", line 287, in call_main
return _spawn_helper(main, cfg, kwargs)
File "/projects/GRP/USER/LLM/metaseq/metaseq/distributed/utils.py", line 265, in _spawn_helper
retval = distributed_main(-1, main, cfg, kwargs)
File "/projects/GRP/USER/LLM/metaseq/metaseq/distributed/utils.py", line 227, in distributed_main
retval = main(cfg, **kwargs)
File "/projects/GRP/USER/LLM/metaseq/metaseq/scripts/convert_to_singleton.py", line 119, in worker_main
models, _model_args, _task = checkpoint_utils.load_model_ensemble_and_task(
File "/projects/GRP/USER/LLM/metaseq/metaseq/checkpoint_utils.py", line 526, in load_model_ensemble_and_task
model.load_state_dict(state["model"], strict=strict)
File "/projects/GRP/USER/LLM/metaseq/metaseq/distributed/fully_sharded_data_parallel.py", line 76, in load_state_dict
return super().load_local_state_dict(state_dict, strict=strict)
File "/projects/GRP/USER/LLM/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1032, in load_local_state_dict
output = self._load_state_dict(state_dict, strict)
File "/projects/GRP/USER/LLM/fairscale/fairscale/nn/data_parallel/fully_sharded_data_parallel.py", line 1008, in _load_state_dict
return self.module.load_state_dict(state_dict, strict)
File "/projects/GRP/USER/LLM/fairscale/fairscale/nn/misc/flatten_params_wrapper.py", line 478, in load_state_dict
return super().load_state_dict(state_dict, strict)
File "/home/USER/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for FlattenParamsWrapper:
Missing key(s) in state_dict: "_fpw_module.decoder.layers.0._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.1._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.2._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.3._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.4._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.5._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.6._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.7._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.8._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.9._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.10._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.11._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.12._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.13._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.14._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.15._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.16._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.17._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.18._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.19._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.20._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.21._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.22._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.23._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.24._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.25._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.26._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.27._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.28._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.29._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.30._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.31._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.32._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.33._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.34._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.35._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.36._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.37._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.38._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.39._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.40._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.41._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.42._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.43._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.44._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.45._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.46._fsdp_wrapped_module.flat_param_0", "_fpw_module.decoder.layers.47._fsdp_wrapped_module.flat_param_0".
Unexpected key(s) in state_dict: "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.out_proj.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn.out_proj.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc1.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc1.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc2.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.fc2.bias", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.decoder.layers.0._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.qkv_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.out_proj.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn.out_proj.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.self_attn_layer_norm.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc1.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc1.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc2.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.fc2.bias", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.weight", "_fpw_module.decoder.layers.1._fsdp_wrapped_module._fpw_module.final_layer_norm.bias", "_fpw_module.decoder.layers.2._fsdp_wrapped
I circumvented this by setting strict=False on line 124 of convert_to_singleton.py, but then trying to torch.load the resulting output produces a different error:
[ins] In [1]: import torch
[ins] In [2]: torch.load('restored.pt', torch.device('cpu'))
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 1
----> 1 torch.load('restored.pt', torch.device('cpu'))
File ~/.local/lib/python3.9/site-packages/torch/serialization.py:777, in load(f, map_location, pickle_module, weights_only, **pickle_load_args)
772 if _is_zipfile(opened_file):
773 # The zipfile reader is going to advance the current file position.
774 # If we want to actually tail call to torch.jit.load, we need to
775 # reset back to the original position.
776 orig_position = opened_file.tell()
--> 777 with _open_zipfile_reader(opened_file) as opened_zipfile:
778 if _is_torchscript_zip(opened_zipfile):
779 warnings.warn("'torch.load' received a zip file that looks like a TorchScript archive"
780 " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
781 " silence this warning)", UserWarning)
File ~/.local/lib/python3.9/site-packages/torch/serialization.py:282, in _open_zipfile_reader.__init__(self, name_or_buffer)
281 def __init__(self, name_or_buffer) -> None:
--> 282 super(_open_zipfile_reader, self).__init__(torch._C.PyTorchFileReader(name_or_buffer))
RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory
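For what it's worth, that last error usually means the file on disk is not a complete zip-based checkpoint (for example, the save was interrupted partway through), rather than a problem with torch.load itself. A quick sanity check, assuming restored.pt is in the working directory:

```python
import os
import zipfile

# torch.save writes a zip archive by default; if the central directory is missing,
# the file is most likely truncated or was never fully written.
print(os.path.getsize("restored.pt"))        # suspiciously small => incomplete save
print(zipfile.is_zipfile("restored.pt"))     # False => not a readable zip archive
```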
Any thoughts? Happy to be an extra pair of eyes/hands on any work converting OPT-IML.
@suchenzang @stephenroller Could you elaborate on how convert_to_singleton.py helps in this case? It appears to merge reshard*.pt files, but the OPT-IML checkpoint does not come with reshard*.pt. We are encountering errors when converting checkpoint_1_4000.pt-model_part-*.pt to reshard*.pt due to the missing 'shard_metadata'.
@linmou I think the released OPT-IML checkpoints have already been consolidated, so there's no need to run the reshard_fsdp.py script to merge FSDP shards (that's why each model parallel part *model_part-*.pt doesn't have any shard metadata). You can load these model parallel parts using the interactive scripts, for example by following these instructions.
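You can verify this by inspecting one of the model parallel parts directly. A minimal sketch (the part-0 filename is an assumption based on the glob pattern earlier in this thread):

```python
import torch

# Load a single model parallel part on CPU and check its structure.
part = torch.load(
    "/OPT-IML-30B/checkpoint_1_6000-model_part-0.pt", map_location="cpu"
)
print("shard_metadata" in part)         # expected False: the part is already consolidated
print(list(part["model"].keys())[:5])   # expected plain parameter names, not FSDP flat_param_* buffers
```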
@muhark @xiangjjj I don't think the released OPT-IML checkpoints work with convert_to_singleton.py, which is supposed to merge all model parallel parts (i.e. *model_part-*.pt files) into a singleton and expects its inputs to contain flattened FSDP weights.
You don't need to merge the model parallel parts in order to load the checkpoints, though. For example, you can make some config changes following these instructions (e.g. setting MODEL_FILE to /OPT-IML-30B/checkpoint_1_6000.pt) and load these checkpoints as is via the referenced interactive scripts.
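For reference, the edits to metaseq/service/constants.py would look roughly like this. Only MODEL_FILE is named above; the other constant names follow the OPT API setup docs, and the model parallel size of 2 is taken from the 30B run earlier in this thread, so treat this as a sketch rather than a drop-in config:

```python
# metaseq/service/constants.py (sketch; verify the constant names against your checkout)
MODEL_PARALLEL = 2                                 # OPT-IML-30B ships 2 model parallel parts
TOTAL_WORLD_SIZE = 2                               # one GPU per model parallel part
CHECKPOINT_FOLDER = "/OPT-IML-30B"
MODEL_FILE = "/OPT-IML-30B/checkpoint_1_6000.pt"   # as suggested in the reply above
```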
That said, if you really want to merge the model parallel parts, please try the new reshard_mp.py script.
@tangbinh thanks for the explanation. The OPT-IML 175B checkpoints come with 16 model parallel parts; how can we run them on 8 GPUs (MP8) as with the original OPT 175B?
@frankxu2004 If you want to convert from MP16 to MP8, please try the reshard_mp.py script (with --num-output-parts 8).
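A sketch of what that invocation could look like, modeled on the reshard_fsdp.py command earlier in this thread. Only --num-output-parts is confirmed here; the input/output flag names and the 175B paths are assumptions, so check python -m metaseq.scripts.reshard_mp --help for the exact interface:

```bash
# Hypothetical flags and paths; only --num-output-parts 8 is confirmed above.
python -m metaseq.scripts.reshard_mp \
    --input-glob-pattern "/OPT-IML-175B/checkpoint_1_4000-model_part-*.pt" \
    --output-shard-name "/OPT-IML-175B/mp8/checkpoint_1_4000-model_part-*.pt" \
    --num-output-parts 8
```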
It's been a while without any update, so I'm closing the issue now. Please let us know if you need further help.