quanto

Non-strict loading of the state dict

Open BenjaminBossan opened this issue 6 months ago • 9 comments

Hi, I'm currently investigating adding optimum-quanto support to PEFT. This mostly works already, but I'm hitting a wall when it comes to loading the `state_dict`. When loading a PEFT adapter like LoRA, we typically assume that the base model weights are already correctly loaded, so we are only interested in loading the adapter weights. That's why we call `model.load_state_dict(peft_model_state_dict, strict=False)` and ignore the missing keys.

Now, when I try to do the same with a quanto-quantized model, I get a `KeyError` for the missing key even though I pass `strict=False`. Below is a reproducer that does not involve PEFT, for simplicity:

# Reproducer: Module.load_state_dict(sd, strict=False) reports a missing key via
# _IncompatibleKeys on a plain model, but raises KeyError on a quanto-quantized
# model. Downloads facebook/opt-125m, so it needs network access.
from transformers import AutoModelForCausalLM
from optimum.quanto import quantize, qint8

model_id = "facebook/opt-125m"

# FIRST WITHOUT QUANTO
model = AutoModelForCausalLM.from_pretrained(model_id)
sd = model.state_dict()
weight = sd.pop("model.decoder.layers.0.self_attn.k_proj.weight")  # delete one item (value unused)
# try with strict=True
try:
    model.load_state_dict(sd)
except RuntimeError as e:
    print(e)
# as expected, prints:
# Error(s) in loading state_dict for OPTForCausalLM:
#	Missing key(s) in state_dict: "model.decoder.layers.0.self_attn.k_proj.weight".

# now strict=False
model.load_state_dict(sd, strict=False)
# passes and returns
# _IncompatibleKeys(missing_keys=['model.decoder.layers.0.self_attn.k_proj.weight'], unexpected_keys=[])

# SECOND WITH QUANTO
model = AutoModelForCausalLM.from_pretrained(model_id)
# `model` is used afterwards without reassignment, so quantize() modifies it in place.
quantize(model, weights=qint8)
sd = model.state_dict()
# The pop succeeds (the failure only appears at load time), so the unquantized-style
# key name is still present in this state_dict.
weight = sd.pop("model.decoder.layers.0.self_attn.k_proj.weight")

# try with strict=True
try:
    model.load_state_dict(sd)
except KeyError as e:  # KeyError, not RuntimeError
    print(e)
# prints:
# 'model.decoder.layers.0.self_attn.k_proj.weight._data'

# now strict=False
# Why strict does not help (per the traceback below): torch calls each module's
# _load_from_state_dict with a hard-coded True and a shared missing_keys list, and
# quanto's QBytesTensor.load_from_state_dict does an unconditional
# state_dict.pop(prefix + "_data"), which raises before the key can be recorded
# as missing.
# This call is intentionally left uncaught so the KeyError traceback is shown.
model.load_state_dict(sd, strict=False)
# same KeyError as with strict=True

Full traceback of the `strict=False` call:

KeyError                                  Traceback (most recent call last)
Cell In[15], line 1
----> 1 model.load_state_dict(sd, strict=False)

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2201, in Module.load_state_dict(self, state_dict, strict, assign)
   2194         out = hook(module, incompatible_keys)
   2195         assert out is None, (
   2196             "Hooks registered with ``register_load_state_dict_post_hook`` are not"
   2197             "expected to return new values, if incompatible_keys need to be modified,"
   2198             "it should be done inplace."
   2199         )
-> 2201 load(self, state_dict)
   2202 del load
   2204 if strict:

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2187         child_prefix = prefix + name + '.'
   2188         child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2189         load(child, child_state_dict, child_prefix)  # noqa: F821
   2191 # Note that the hook can modify missing_keys and unexpected_keys.
   2192 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2187         child_prefix = prefix + name + '.'
   2188         child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2189         load(child, child_state_dict, child_prefix)  # noqa: F821
   2191 # Note that the hook can modify missing_keys and unexpected_keys.
   2192 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

    [... skipping similar frames: Module.load_state_dict.<locals>.load at line 2189 (3 times)]

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2187         child_prefix = prefix + name + '.'
   2188         child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
-> 2189         load(child, child_state_dict, child_prefix)  # noqa: F821
   2191 # Note that the hook can modify missing_keys and unexpected_keys.
   2192 incompatible_keys = _IncompatibleKeys(missing_keys, unexpected_keys)

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/torch/nn/modules/module.py:2183, in Module.load_state_dict.<locals>.load(module, local_state_dict, prefix)
   2181 if assign:
   2182     local_metadata['assign_to_params_buffers'] = assign
-> 2183 module._load_from_state_dict(
   2184     local_state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
   2185 for name, child in module._modules.items():
   2186     if child is not None:

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/optimum/quanto/nn/qmodule.py:159, in QModuleMixin._load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
    157 weight_prefix = weight_name + "."
    158 if self.weight_qtype.bits == 8:
--> 159     deserialized_weight = QBytesTensor.load_from_state_dict(
    160         state_dict,
    161         weight_prefix,
    162         qtype=self.weight_qtype,
    163         axis=0,
    164         size=self.weight.size(),
    165         stride=self.weight.stride(),
    166     )
    167 else:
    168     deserialized_weight = QBitsTensor.load_from_state_dict(
    169         state_dict,
    170         weight_prefix,
   (...)
    175         stride=self.weight.stride(),
    176     )

File ~/anaconda3/envs/peft/lib/python3.11/site-packages/optimum/quanto/tensor/qbytes.py:90, in QBytesTensor.load_from_state_dict(state_dict, prefix, qtype, axis, size, stride)
     88 inner_tensors_dict = {}
     89 for name in ["_data", "_scale"]:
---> 90     inner_tensors_dict[name] = state_dict.pop(prefix + name)
     91 meta = {
     92     "qtype": qtype.name,
     93     "axis": str(axis),
     94     "size": str(list(size)),
     95     "stride": str(list(stride)),
     96 }
     97 return QBytesTensor.__tensor_unflatten__(inner_tensors_dict, meta, None, None)

KeyError: 'model.decoder.layers.0.self_attn.k_proj.weight._data'

— BenjaminBossan, Aug 12 '24 14:08