metaseq
How to load an OPT-IML checkpoint and use it for inference?
Thanks for your work on OPT-IML. However, I'm confused about how to load the OPT-IML checkpoints for inference.
Code
Following the instructions in Inference API, I ran the following command:
python -m metaseq.scripts.reshard_fsdp \
--input-glob-pattern "/OPT-IML-30B/checkpoint_1_6000-model_part-*.pt" \
--output-shard-name "/OPT-IML-30B/reshard-checkpoint_1_6000-model_part-*.pt" \
--num-output-shards 1 --skip-optimizer-state True --unflatten-weights True
It fails with the following error:
2023-01-07 13:05:54,294 | metaseq.scripts.reshard_fsdp | Loading all sharded checkpoints to CPU
2023-01-07 13:08:02,081 | metaseq.scripts.reshard_fsdp | Resharding state dicts into 1 shard(s)
Traceback (most recent call last):
File "/opt/conda/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 242, in <module>
fire.Fire(reshard_fsdp_checkpoints)
File "/opt/conda/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
component_trace = _Fire(component, args, parsed_flag_args, context, name)
File "/opt/conda/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
File "/opt/conda/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
component = fn(*varargs, **kwargs)
File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 53, in reshard_fsdp_checkpoints
resharded_state_dicts = reshard_fsdp_state_dicts(
File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in reshard_fsdp_state_dicts
shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
File "/mnt/bd/metaseq/metaseq/scripts/reshard_fsdp.py", line 76, in <listcomp>
shard_metadata=[s["shard_metadata"] for s in shard_state_dicts],
KeyError: 'shard_metadata'
It seems the checkpoint is missing the key 'shard_metadata', which the resharding script reads from every shard. Could you offer any insight into how to fix this?
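The traceback shows that reshard_fsdp_state_dicts reads s["shard_metadata"] from every loaded shard, so one way to confirm which files lack that key is to inspect their top-level keys before resharding. The following is a minimal diagnostic sketch, not part of metaseq: the helper names missing_key_report and check_checkpoint_files are hypothetical, and it assumes each shard file is a plain dict written with torch.save. Depending on your PyTorch version, torch.load may also need weights_only=False to read checkpoints that contain non-tensor objects.

```python
import glob
from typing import Any, Dict, List, Mapping


def missing_key_report(
    shard_state_dicts: List[Mapping[str, Any]], key: str = "shard_metadata"
) -> Dict[int, List[str]]:
    """Return {shard index: sorted top-level keys} for every shard lacking `key`."""
    report: Dict[int, List[str]] = {}
    for i, sd in enumerate(shard_state_dicts):
        if key not in sd:
            # Record what the shard does contain, to help spot the format it was saved in.
            report[i] = sorted(sd.keys())
    return report


def check_checkpoint_files(pattern: str, key: str = "shard_metadata") -> Dict[str, List[str]]:
    """Load files matching `pattern` one at a time on CPU and report those lacking `key`.

    Hypothetical helper: assumes each file is a dict saved with torch.save.
    torch is imported lazily so missing_key_report works without it.
    """
    import torch  # only needed when actually reading checkpoint files

    result: Dict[str, List[str]] = {}
    for path in sorted(glob.glob(pattern)):
        # Load a single shard at a time to keep peak CPU memory bounded.
        sd = torch.load(path, map_location="cpu")
        for keys in missing_key_report([sd], key).values():
            result[path] = keys
        del sd
    return result


if __name__ == "__main__":
    # Toy stand-ins for loaded shards: the second one lacks "shard_metadata".
    shards = [
        {"model": {}, "shard_metadata": {"param_metadata": []}},
        {"model": {}, "cfg": {}},
    ]
    print(missing_key_report(shards))  # {1: ['cfg', 'model']}
```

If every shard shows up in the report, the files were likely saved in a format other than the FSDP-sharded one that reshard_fsdp expects, rather than being individually corrupted.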