ColossalAI
[BUG]: Failed to load HuggingFace pretrained checkpoint with LazyInitContext
🐛 Describe the bug
Using LazyInitContext
and later loading a checkpoint does not properly initialize model parameters.
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2LMHeadModel
colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=27700)
plugin = HybridParallelPlugin(tp_size=1, pp_size=1)
booster = Booster(plugin=plugin)
with LazyInitContext():
    model = GPT2LMHeadModel.from_pretrained("gpt2")
model, *_ = booster.boost(model)
The booster uses HybridParallelCheckpointIO.load_unsharded_model
to load the GPT2 HF checkpoint:
# Load from checkpoint. Since the logic of breaking parameter shards along tp degree
# has been implemented by _load_from_state_dict method of ParallelModule in Shardformer,
# model.load_state_dict can be directly called.
state_dict = load_state_dict(checkpoint)
model.load_state_dict(state_dict, strict=strict)
When load_state_dict()
is called, because GPT2LMHeadModel wraps the base GPT2Model under the transformer
prefix, none of the GPT2 parameters are initialized:
model.load_state_dict(state_dict, strict=strict)
missing_keys: ['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', ...]
unexpected_keys: ['wte.weight', 'wpe.weight', 'h.0.ln_1.weight', ...]
list(model.parameters())[0]
Parameter containing:
tensor([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], device='cuda:0', dtype=torch.float16,
requires_grad=True)
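To make the key mismatch concrete, here is a minimal sketch using plain transformers (no ColossalAI involved) that shows where the transformer. prefix comes from; the checkpoint keys themselves are the ones in the unexpected_keys output above:

# GPT2LMHeadModel prefixes every GPT2 parameter with base_model_prefix
# ("transformer"), so the un-prefixed checkpoint keys above never match.
from transformers import GPT2Config, GPT2LMHeadModel

m = GPT2LMHeadModel(GPT2Config())
print(m.base_model_prefix)                  # "transformer"
print(list(m.state_dict().keys())[:2])      # ['transformer.wte.weight', 'transformer.wpe.weight']
# The checkpoint handed to load_unsharded_model() instead contains keys such as
# 'wte.weight' and 'wpe.weight', so with strict=False they are silently ignored.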
The HuggingFace from_pretrained
method uses base_model_prefix
to adjust keys before loading parameters:
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
prefix = model.base_model_prefix

original_loaded_keys = loaded_keys
loaded_keys = [_fix_key(key) for key in loaded_keys]

if len(prefix) > 0:
    has_prefix_module = any(s.startswith(prefix) for s in loaded_keys)
    expects_prefix_module = any(s.startswith(prefix) for s in expected_keys)
else:
    has_prefix_module = False
    expects_prefix_module = False

# key re-naming operations are never done on the keys
# that are loaded, but always on the keys of the newly initialized model
remove_prefix_from_model = not has_prefix_module and expects_prefix_module
add_prefix_to_model = has_prefix_module and not expects_prefix_module

if remove_prefix_from_model:
    _prefix = f"{prefix}."
    expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
    expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
elif add_prefix_to_model:
    expected_keys = [".".join([prefix, s]) for s in expected_keys]
...
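For comparison, here is a rough workaround sketch (not an official ColossalAI fix) that re-keys the checkpoint the same way from_pretrained does before model.load_state_dict() is called; the helper name is mine, and it assumes it runs against the unwrapped HF module:

import torch.nn as nn

def remap_keys_for_model(state_dict: dict, model: nn.Module) -> dict:
    # Mirror transformers' prefix handling: add or strip base_model_prefix
    # ("transformer" for GPT2LMHeadModel) depending on what the model expects.
    prefix = getattr(model, "base_model_prefix", "")
    if not prefix:
        return state_dict
    expected_keys = model.state_dict().keys()
    has_prefix = any(k.startswith(f"{prefix}.") for k in state_dict)
    expects_prefix = any(k.startswith(f"{prefix}.") for k in expected_keys)
    if expects_prefix and not has_prefix:
        return {f"{prefix}.{k}": v for k, v in state_dict.items()}
    if has_prefix and not expects_prefix:
        return {k[len(prefix) + 1:] if k.startswith(f"{prefix}.") else k: v
                for k, v in state_dict.items()}
    return state_dict

# state_dict = remap_keys_for_model(load_state_dict(checkpoint), model)
# model.load_state_dict(state_dict, strict=strict)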
Please let me know if I am using LazyInitContext
incorrectly. Did I do something wrong?
Environment
Python 3.10.13, Torch 2.2.1 / CUDA 12.1, Transformers 4.38.2
@ver217
Hi @insujang, the reason why applying LazyInitContext
does not load pretrained parameters into the model is that, under the lazy init context, lazy tensors are not materialized until they are used.
Based on your script, the GPT2LMHeadModel should work well for inference:
from transformers import pipeline
generator = pipeline('text-generation', model=MODEL_PATH)
prompt_text = "Once upon a time"
generated_texts = generator(prompt_text, max_length=50, num_return_sequences=1)
for text in generated_texts:
    print(text['generated_text'])
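For reference, here is a minimal sketch of the deferred materialization being described; LazyInitContext.materialize is used here only to illustrate, under that assumption, what booster.boost()/sharding does internally:

import torch.nn as nn
from colossalai.lazy import LazyInitContext

# Under the context, parameters are created as lazy placeholders without real storage.
with LazyInitContext():
    layer = nn.Linear(4, 4)
print(type(layer.weight).__name__)   # a lazy tensor type, not a plain Parameter

# Materialization normally happens inside booster.boost() during sharding;
# calling it explicitly (assumed API) turns the placeholder into a regular tensor.
LazyInitContext.materialize(layer)
print(type(layer.weight).__name__, tuple(layer.weight.shape))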
Hi @char-1ee, thank you for your answer. I'm afraid I cannot understand your explanation. Your example does not even use LazyInitContext
at all.
Parameters are already materialized at the moment load_model()
is called: booster.boost()
calls plugin.configure() -> shardformer.optimize() -> ModelSharder.shard()
, and during sharding the lazy parameters are sharded and materialized.
According to this document, we should use booster.load_model()
to load the parameters from the pretrained checkpoint. I understand that just applying LazyInitContext
does not load parameters, but my example calls booster.boost()
, which automatically calls self.load_model()
:
https://github.com/hpcaitech/ColossalAI/blob/7e0ec5a85c73fcc5666b9d218e43865141587dde/colossalai/booster/booster.py#L151-L152
which uses load_state_dict() and fails to load the parameters.
Edit: also, a LazyTensor does not print zeros; it would look like LazyTensor(..., size=(x, y), device=..., dtype=...)
. The fact that my example shows actual data in the tensor proves the parameters have been materialized.
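For completeness, this is the kind of quick check behind that claim, run against the repro script above:

import torch

# After booster.boost() every parameter is an ordinary, materialized tensor,
# but under this bug it holds only zeros instead of the pretrained values.
for name, p in model.named_parameters():
    print(name, type(p).__name__, tuple(p.shape), torch.count_nonzero(p).item())
    break  # the first parameter (the embedding shown above) reports 0 non-zero elements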
Hi @insujang, I mean the full script works for me:
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2LMHeadModel, pipeline

MODEL_PATH = "/path/to/your/gpt2/model"

colossalai.launch(config={}, rank=0, world_size=1, host="localhost", port=27700)
plugin = HybridParallelPlugin(tp_size=1, pp_size=1)
booster = Booster(plugin=plugin)

with LazyInitContext():
    model = GPT2LMHeadModel.from_pretrained(MODEL_PATH)
model, *_ = booster.boost(model)

# Test inference
generator = pipeline('text-generation', model=MODEL_PATH)
prompt_text = "Once upon a time"
generated_texts = generator(prompt_text, max_length=50, num_return_sequences=1)
for text in generated_texts:
    print(text['generated_text'])
Yes, the lazy tensors should be materialized when sharding: ColossalAI/colossalai/shardformer/shard/sharder.py.
Also, booster.boost()
calls booster.load_model()
; load_model()
then calls the corresponding checkpoint IO to load the pretrained checkpoint, which is HybridParallelCheckpointIO
, and that calls its load_unsharded_model()
here. Though the pretrained params are zero-like, it does not report an error on my side.
May I know how you get this error:
missing_keys: ['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', ...]
unexpected_keys: ['wte.weight', 'wpe.weight', 'h.0.ln_1.weight', ...]
Oh I see. Have you checked that the parameters are properly loaded? I should have been clearer: the code runs without an explicit error being returned, as you said. But I think you also saw that the parameters are zero-like, which is the bug.
May I know how you get this error:
https://github.com/hpcaitech/ColossalAI/blob/8e412a548e5366d1c42bcf386bd185091bd0c280/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py#L681
Print the result of this load_state_dict() call; the result (the missing and unexpected keys) is discarded in the production code.
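Concretely, a sketch of the print I added around the linked line; incompatible_keys is just my name for the return value:

# nn.Module.load_state_dict() returns a namedtuple (missing_keys, unexpected_keys);
# the checkpoint IO currently discards it, so the mismatch never surfaces.
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
print("missing_keys:", incompatible_keys.missing_keys[:3])
print("unexpected_keys:", incompatible_keys.unexpected_keys[:3])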