Cell Type Annotation scGPT_Human model state dict not correct
Hi, thank you for your amazing work developing scGPT.
I am following the Tutorial_Annotation.ipynb tutorial using the ms dataset from https://drive.google.com/drive/folders/1Qd42YNabzyr2pWt9xoY4cVMTAxsNBt4v and the model from https://drive.google.com/drive/folders/1oWh_-ZRdhtoGQ2Fw24HP41FgLoomVo-y
For this section of the tutorial (model loading):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ntokens = len(vocab) # size of vocabulary
model = TransformerModel(
ntokens,
embsize,
nhead,
d_hid,
nlayers,
nlayers_cls=3,
n_cls=num_types if CLS else 1,
vocab=vocab,
dropout=dropout,
pad_token=pad_token,
pad_value=pad_value,
do_mvc=MVC,
do_dab=DAB,
use_batch_labels=INPUT_BATCH_LABELS,
num_batch_labels=num_batch_types,
domain_spec_batchnorm=config.DSBN,
input_emb_style=input_emb_style,
n_input_bins=n_input_bins,
cell_emb_style=cell_emb_style,
mvc_decoder_style=mvc_decoder_style,
ecs_threshold=ecs_threshold,
explicit_zero_prob=explicit_zero_prob,
use_fast_transformer=fast_transformer,
fast_transformer_backend=fast_transformer_backend,
pre_norm=config.pre_norm,
)
if config.load_model is not None:
try:
model.load_state_dict(torch.load(model_file))
logger.info(f"Loading all model params from {model_file}")
except:
# only load params that are in the model and match the size
model_dict = model.state_dict()
pretrained_dict = torch.load(model_file)
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if k in model_dict and v.shape == model_dict[k].shape
}
for k, v in pretrained_dict.items():
logger.info(f"Loading params {k} with shape {v.shape}")
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
pre_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())
# Freeze all pre-decoder weights
for name, para in model.named_parameters():
print("-"*20)
print(f"name: {name}")
if config.freeze and "encoder" in name and "transformer_encoder" not in name:
# if config.freeze and "encoder" in name:
print(f"freezing weights for: {name}")
para.requires_grad = False
post_freeze_param_count = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters() if p.requires_grad).values())
logger.info(f"Total Pre freeze Params {(pre_freeze_param_count )}")
logger.info(f"Total Post freeze Params {(post_freeze_param_count )}")
wandb.log(
{
"info/pre_freeze_param_count": pre_freeze_param_count,
"info/post_freeze_param_count": post_freeze_param_count,
},
)
model.to(device)
wandb.watch(model)
if ADV:
discriminator = AdversarialDiscriminator(
d_model=embsize,
n_cls=num_batch_types,
).to(device)
I have the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[27], [line 35](vscode-notebook-cell:?execution_count=27&line=35)
[34](vscode-notebook-cell:?execution_count=27&line=34) try:
---> [35](vscode-notebook-cell:?execution_count=27&line=35) model.load_state_dict(torch.load(model_file,map_location= torch.device('mps')))
[36](vscode-notebook-cell:?execution_count=27&line=36) logger.info(f"Loading all model params from {model_file}")
File /Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:2189, in Module.load_state_dict(self, state_dict, strict, assign)
[2188](https://file+.vscode-resource.vscode-cdn.net/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:2188) if len(error_msgs) > 0:
-> [2189](https://file+.vscode-resource.vscode-cdn.net/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:2189) raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[2190](https://file+.vscode-resource.vscode-cdn.net/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:2190) self.__class__.__name__, "\n\t".join(error_msgs)))
[2191](https://file+.vscode-resource.vscode-cdn.net/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:2191) return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for TransformerModel:
Missing key(s) in state_dict: "transformer_encoder.layers.0.self_attn.in_proj_weight", "transformer_encoder.layers.0.self_attn.in_proj_bias", "transformer_encoder.layers.1.self_attn.in_proj_weight", "transformer_encoder.layers.1.self_attn.in_proj_bias", "transformer_encoder.layers.2.self_attn.in_proj_weight", "transformer_encoder.layers.2.self_attn.in_proj_bias", "transformer_encoder.layers.3.self_attn.in_proj_weight", "transformer_encoder.layers.3.self_attn.in_proj_bias", "transformer_encoder.layers.4.self_attn.in_proj_weight", "transformer_encoder.layers.4.self_attn.in_proj_bias", "transformer_encoder.layers.5.self_attn.in_proj_weight", "transformer_encoder.layers.5.self_attn.in_proj_bias", "transformer_encoder.layers.6.self_attn.in_proj_weight", "transformer_encoder.layers.6.self_attn.in_proj_bias", "transformer_encoder.layers.7.self_attn.in_proj_weight", "transformer_encoder.layers.7.self_attn.in_proj_bias", "transformer_encoder.layers.8.self_attn.in_proj_weight", "transformer_encoder.layers.8.self_attn.in_proj_bias", "transformer_encoder.layers.9.self_attn.in_proj_weight", "transformer_encoder.layers.9.self_attn.in_proj_bias", "transformer_encoder.layers.10.self_attn.in_proj_weight", "transformer_encoder.layers.10.self_attn.in_proj_bias", "transformer_encoder.layers.11.self_attn.in_proj_weight", "transformer_encoder.layers.11.self_attn.in_proj_bias", "cls_decoder._decoder.0.weight", "cls_decoder._decoder.0.bias", "cls_decoder._decoder.2.weight", "cls_decoder._decoder.2.bias", "cls_decoder._decoder.3.weight", "cls_decoder._decoder.3.bias", "cls_decoder._decoder.5.weight", "cls_decoder._decoder.5.bias", "cls_decoder.out_layer.weight", "cls_decoder.out_layer.bias".
Unexpected key(s) in state_dict: "flag_encoder.weight", "mvc_decoder.gene2query.weight", "mvc_decoder.gene2query.bias", "mvc_decoder.W.weight", "transformer_encoder.layers.0.self_attn.Wqkv.weight", "transformer_encoder.layers.0.self_attn.Wqkv.bias", "transformer_encoder.layers.1.self_attn.Wqkv.weight", "transformer_encoder.layers.1.self_attn.Wqkv.bias", "transformer_encoder.layers.2.self_attn.Wqkv.weight", "transformer_encoder.layers.2.self_attn.Wqkv.bias", "transformer_encoder.layers.3.self_attn.Wqkv.weight", "transformer_encoder.layers.3.self_attn.Wqkv.bias", "transformer_encoder.layers.4.self_attn.Wqkv.weight", "transformer_encoder.layers.4.self_attn.Wqkv.bias", "transformer_encoder.layers.5.self_attn.Wqkv.weight", "transformer_encoder.layers.5.self_attn.Wqkv.bias", "transformer_encoder.layers.6.self_attn.Wqkv.weight", "transformer_encoder.layers.6.self_attn.Wqkv.bias", "transformer_encoder.layers.7.self_attn.Wqkv.weight", "transformer_encoder.layers.7.self_attn.Wqkv.bias", "transformer_encoder.layers.8.self_attn.Wqkv.weight", "transformer_encoder.layers.8.self_attn.Wqkv.bias", "transformer_encoder.layers.9.self_attn.Wqkv.weight", "transformer_encoder.layers.9.self_attn.Wqkv.bias", "transformer_encoder.layers.10.self_attn.Wqkv.weight", "transformer_encoder.layers.10.self_attn.Wqkv.bias", "transformer_encoder.layers.11.self_attn.Wqkv.weight", "transformer_encoder.layers.11.self_attn.Wqkv.bias".
I did not change any code from the tutorial, just loaded the datasets to the correct folders.
To help solve this issue, check this, for the transformer model, the number of states is different between the loaded and the model architecture (yes I am loading on "mps" because using MacM3 that does not have CUDA via GPU, it is not related to this)
To help you further please check this:
Keys in the current model state_dict but not in the loaded state_dict:
{'cls_decoder._decoder.3.bias', 'transformer_encoder.layers.8.self_attn.in_proj_weight', 'transformer_encoder.layers.8.self_attn.in_proj_bias', 'transformer_encoder.layers.2.self_attn.in_proj_bias', 'transformer_encoder.layers.10.self_attn.in_proj_bias', 'transformer_encoder.layers.11.self_attn.in_proj_bias', 'transformer_encoder.layers.7.self_attn.in_proj_weight', 'transformer_encoder.layers.1.self_attn.in_proj_weight', 'transformer_encoder.layers.3.self_attn.in_proj_weight', 'transformer_encoder.layers.11.self_attn.in_proj_weight', 'cls_decoder._decoder.2.bias', 'transformer_encoder.layers.7.self_attn.in_proj_bias', 'cls_decoder._decoder.3.weight', 'transformer_encoder.layers.9.self_attn.in_proj_bias', 'cls_decoder._decoder.5.bias', 'cls_decoder.out_layer.bias', 'transformer_encoder.layers.9.self_attn.in_proj_weight', 'cls_decoder.out_layer.weight', 'transformer_encoder.layers.5.self_attn.in_proj_weight', 'transformer_encoder.layers.4.self_attn.in_proj_bias', 'transformer_encoder.layers.6.self_attn.in_proj_weight', 'transformer_encoder.layers.4.self_attn.in_proj_weight', 'cls_decoder._decoder.0.weight', 'cls_decoder._decoder.2.weight', 'transformer_encoder.layers.3.self_attn.in_proj_bias', 'cls_decoder._decoder.5.weight', 'transformer_encoder.layers.5.self_attn.in_proj_bias', 'transformer_encoder.layers.0.self_attn.in_proj_bias', 'transformer_encoder.layers.0.self_attn.in_proj_weight', 'transformer_encoder.layers.1.self_attn.in_proj_bias', 'cls_decoder._decoder.0.bias', 'transformer_encoder.layers.6.self_attn.in_proj_bias', 'transformer_encoder.layers.10.self_attn.in_proj_weight', 'transformer_encoder.layers.2.self_attn.in_proj_weight'}
Keys in the loaded state_dict but not in the current model state_dict:
{'transformer_encoder.layers.9.self_attn.Wqkv.weight', 'mvc_decoder.gene2query.bias', 'transformer_encoder.layers.3.self_attn.Wqkv.weight', 'transformer_encoder.layers.1.self_attn.Wqkv.weight', 'transformer_encoder.layers.3.self_attn.Wqkv.bias', 'transformer_encoder.layers.6.self_attn.Wqkv.bias', 'transformer_encoder.layers.4.self_attn.Wqkv.bias', 'transformer_encoder.layers.9.self_attn.Wqkv.bias', 'mvc_decoder.gene2query.weight', 'flag_encoder.weight', 'transformer_encoder.layers.1.self_attn.Wqkv.bias', 'transformer_encoder.layers.8.self_attn.Wqkv.bias', 'transformer_encoder.layers.2.self_attn.Wqkv.bias', 'transformer_encoder.layers.11.self_attn.Wqkv.weight', 'transformer_encoder.layers.10.self_attn.Wqkv.bias', 'transformer_encoder.layers.6.self_attn.Wqkv.weight', 'transformer_encoder.layers.7.self_attn.Wqkv.bias', 'transformer_encoder.layers.7.self_attn.Wqkv.weight', 'transformer_encoder.layers.2.self_attn.Wqkv.weight', 'transformer_encoder.layers.5.self_attn.Wqkv.bias', 'transformer_encoder.layers.8.self_attn.Wqkv.weight', 'transformer_encoder.layers.4.self_attn.Wqkv.weight', 'transformer_encoder.layers.5.self_attn.Wqkv.weight', 'transformer_encoder.layers.0.self_attn.Wqkv.weight', 'transformer_encoder.layers.11.self_attn.Wqkv.bias', 'transformer_encoder.layers.10.self_attn.Wqkv.weight', 'transformer_encoder.layers.0.self_attn.Wqkv.bias', 'mvc_decoder.W.weight'}
In sum, the model architecture does not match the loaded model on google drive.
I've encountered a similar issue. Could someone please explain how I should resolve this problem?
Hi, just updating on this issue. Solved it by running on a linux machine on aws
is this related to your loading on "mps" ?