Non-strict loading of the state dict
Hi, I'm currently investigating the addition of optimum-quanto to PEFT. This mostly works already, but I'm hitting a wall when it comes to loading the state_dict. When loading a PEFT adapter like LoRA, we typically assume that the base model weights are already correctly loaded and are thus only interested in loading the adapter weights. That's why we call model.load_state_dict(peft_model_state_dict, strict=False) and ignore the missing keys.
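For context, here is a minimal sketch of that pattern with a plain (non-quanto) model; the key filter below is only illustrative, real PEFT code maps dedicated adapter keys rather than reusing a subset of base-model keys:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# Pretend this is an adapter-only state dict: keep only a small subset of keys
# (an illustrative stand-in for peft_model_state_dict).
partial_sd = {k: v for k, v in model.state_dict().items() if "layers.0." in k}
# All other keys are "missing" on purpose; with strict=False they are only
# reported, not raised.
result = model.load_state_dict(partial_sd, strict=False)
print(len(result.missing_keys))  # the base-model keys we deliberately ignore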
Now when I try to do that with a quanto model, I get an error about a missing key despite passing strict=False. Below is a reproducer that does not involve PEFT, for simplicity:
from transformers import AutoModelForCausalLM
from optimum.quanto import quantize, qint8
model_id = "facebook/opt-125m"
# FIRST WITHOUT QUANTO
model = AutoModelForCausalLM.from_pretrained(model_id)
sd = model.state_dict()
weight = sd.pop("model.decoder.layers.0.self_attn.k_proj.weight") # delete one item
# try with strict=True
try:
model.load_state_dict(sd)
except RuntimeError as e:
print(e)
# as expected, prints:
# Error(s) in loading state_dict for OPTForCausalLM:
# Missing key(s) in state_dict: "model.decoder.layers.0.self_attn.k_proj.weight".
# now strict=False
model.load_state_dict(sd, strict=False)
# passes and returns
# _IncompatibleKeys(missing_keys=['model.decoder.layers.0.self_attn.k_proj.weight'], unexpected_keys=[])
# SECOND WITH QUANTO
model = AutoModelForCausalLM.from_pretrained(model_id)
quantize(model, weights=qint8)
sd = model.state_dict()
weight = sd.pop("model.decoder.layers.0.self_attn.k_proj.weight")
# try with strict=True
try:
model.load_state_dict(sd)
except KeyError as e: # KeyError, not RuntimeError
print(e)
# prints:
# 'model.decoder.layers.0.self_attn.k_proj.weight._data'
# now strict=False
model.load_state_dict(sd, strict=False)
# same KeyError as with strict=True
Full error:
KeyError Traceback (most recent call last)
Cell In[15], line 1
----> 1 model.load_state_dict(sd, strict=False)
File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2201, in Module.load_state_dict(self, state_dict, strict, assign)
2194 out = hook(module, incompatible_keys)
2195 assert out is None, (
2196 "Hooks registered with ``register_load_state_dict_post_hook`` are not"
2197 "expected to return new values, if incompatible_keys need to be modified,"
2198 "it should be done inplace."
2199 )
-> 2201 load(self, state_dict)
2202 del load
2204 if strict:
File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2187 child_prefix = prefix + name + '.'
2188 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2189 load(child, child_state_dict, child_prefix) # noqa: F821
2191 # Note that the hook can modify missing_keys and unexpected_keys.
2192 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2187 child_prefix = prefix + name + '.'
2188 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2189 load(child, child_state_dict, child_prefix) # noqa: F821
2191 # Note that the hook can modify missing_keys and unexpected_keys.
2192 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
[... skipping similar frames: Module.load_state_dict.<locals>.load at line 2189 (3 times)]
File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2187 child_prefix = prefix + name + '.'
2188 child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2189 load(child, child_state_dict, child_prefix) # noqa: F821
2191 # Note that the hook can modify missing_keys and unexpected_keys.
2192 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)
File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2183, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
2181 if assign:
2182 local_metadata['assign_to_params_buffers'] = assign
-> 2183 module._load_from_state_dict(
2184 local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
2185 for name, child in module._modules.items():
2186 if child is not None:
File ~/anaconda3/envs/peft/lib/python3.11/site-packages/optimum/quanto/nn/qmodule.py:159, in QModuleMixin._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
157 weight_prefix = weight_name + "."
158 if self.weight_qtype.bits == 8:
--> 159 deserialized_weight = QBytesTensor.load_from_state_dict(
160 state_dict,
161 weight_prefix,
162 qtype=self.weight_qtype,
163 axis=0,
164 size=self.weight.size(),
165 stride=self.weight.stride(),
166 )
167 else:
168 deserialized_weight = QBitsTensor.load_from_state_dict(
169 state_dict,
170 weight_prefix,
(...)
175 stride=self.weight.stride(),
176 )
File ~/anaconda3/envs/peft/lib/python3.11/site-packages/optimum/quanto/tensor/qbytes.py:90, in QBytesTensor.load_from_state_dict(state_dict, prefix, qtype, axis, size, stride)
88 inner_tensors_dict = {}
89 for name in ["_data", "_scale"]:
---> 90 inner_tensors_dict[name] = state_dict.pop(prefix + name)
91 meta = {
92 "qtype": qtype.name,
93 "axis": str(axis),
94 "size": str(list(size)),
95 "stride": str(list(stride)),
96 }
97 return QBytesTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)
KeyError: 'model.decoder.layers.0.self_attn.k_proj.weight._data'
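From the traceback, the cause seems to be that QBytesTensor.load_from_state_dict pops prefix + "_data" and prefix + "_scale" from the state dict unconditionally (last frame, qbytes.py), so an absent weight entry raises a KeyError instead of being appended to missing_keys, which is why strict=False makes no difference. As a caller-side stopgap I could use in PEFT in the meantime (just a sketch, not a proposed fix for quanto), one option seems to be to fill the gaps with the model's own current tensors before loading, so that no key is actually absent:

from transformers import AutoModelForCausalLM
from optimum.quanto import quantize, qint8

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
quantize(model, weights=qint8)

sd = model.state_dict()
sd.pop("model.decoder.layers.0.self_attn.k_proj.weight")  # simulate a partial state dict

# Workaround sketch: overlay the partial state dict on top of the model's
# current state dict, so every key quanto tries to pop is present.
merged_sd = model.state_dict()
merged_sd.update(sd)
model.load_state_dict(merged_sd, strict=False)  # no KeyError

This of course defeats part of the purpose of strict=False (it materializes a full state dict again), so it would still be nice if quanto's _load_from_state_dict could report the absent key via missing_keys instead of raising.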