SAM2-UNet

How to train with the tiny checkpoint?

Status: Open. FeiYull opened this issue 5 months ago • 2 comments

I have made the following changes:

  1. Edited train.py (line 15): https://github.com/WZH0120/SAM2-UNet/blob/eb1c38d870358cbdd769c9721062f7bb888ef9b5/train.py#L15
  2. Edited the YAML config setting in SAM2UNet.py (line 127): https://github.com/WZH0120/SAM2-UNet/blob/eb1c38d870358cbdd769c9721062f7bb888ef9b5/SAM2UNet.py#L127
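For these two edits to work together, the YAML selected in SAM2UNet.py and the checkpoint used for training have to describe the same Hiera variant, and any layer that consumes the encoder's multi-scale features has to use that variant's channel widths. Below is a minimal sketch of keeping them paired. It assumes the SAM2 (v2.0) file names as I recall them and that Hiera doubles its width at each of its four stages; `HieraVariant`, `VARIANTS` and `stage_channels` are hypothetical names, not part of SAM2-UNet or SAM2.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HieraVariant:
    config: str      # YAML handed to the model builder
    checkpoint: str  # weights file trained for that same YAML
    embed_dim: int   # channel width of the first Hiera stage


# Assumed SAM2 v2.0 names and widths; verify against your local copy.
VARIANTS = {
    "tiny": HieraVariant("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt", 96),
    "small": HieraVariant("sam2_hiera_s.yaml", "sam2_hiera_small.pt", 96),
    "base_plus": HieraVariant("sam2_hiera_b+.yaml", "sam2_hiera_base_plus.pt", 112),
    "large": HieraVariant("sam2_hiera_l.yaml", "sam2_hiera_large.pt", 144),
}


def stage_channels(variant: str, num_stages: int = 4) -> list:
    """Per-stage feature widths, assuming the width doubles at every stage."""
    dim = VARIANTS[variant].embed_dim
    return [dim * 2 ** i for i in range(num_stages)]


if __name__ == "__main__":
    tiny = VARIANTS["tiny"]
    print(tiny.config, tiny.checkpoint)
    print(stage_channels("tiny"))
    print(stage_channels("large"))
```

If SAM2UNet.py hard-codes the Hiera-L widths (144, 288, 576, 1152) anywhere after the encoder, which is an assumption worth checking, those numbers would also need to change to the tiny widths (96, 192, 384, 768).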

but loading the checkpoint then fails with errors like:

python3.9/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SAM2Base:
    Missing key(s) in state_dict: "image_encoder.trunk.blocks.2.proj.weight", "image_encoder.trunk.blocks.2.proj.bias", "image_encoder.trunk.blocks.8.proj.weight", "image_encoder.trunk.blocks.8.proj.bias", "image_encoder.trunk.blocks.12.norm1.weight", "image_encoder.trunk.blocks.12.norm1.bias", "image_encoder.trunk.blocks.12.attn.qkv.weight", "image_encoder.trunk.blocks.12.attn.qkv.bias", "image_encoder.trunk.blocks.12.attn.proj.weight", "image_encoder.trunk.blocks.12.attn.proj.bias", "image_encoder.trunk.blocks.12.norm2.weight", "image_encoder.trunk.blocks.12.norm2.bias", "image_encoder.trunk.blocks.12.mlp.layers.0.weight", "

...........

    Unexpected key(s) in state_dict: "image_encoder.trunk.blocks.1.proj.weight", "image_encoder.trunk.blocks.1.proj.bias", "image_encoder.trunk.blocks.3.proj.weight", "image_encoder.trunk.blocks.3.proj.bias", "image_encoder.trunk.blocks.10.proj.weight", "image_encoder.trunk.blocks.10.proj.bias".
    size mismatch for image_encoder.trunk.pos_embed: copying a param with shape torch.Size([1, 96, 7, 7]) from checkpoint, the shape in current model is torch.Size([1, 144, 7, 7]).
    size mismatch for image_encoder.trunk.pos_embed_window: copying a param with shape torch.Size([1, 96, 8, 8]) from checkpoint, the shape in current model is torch.Size([1, 144, 8, 8]).
    size mismatch for image_encoder.trunk.patch_embed.proj.weight: copying a param with shape torch.Size([96, 3, 7, 7]) from checkpoint, the shape in current model is torch.Size([144, 3, 7, 7]).
    size mismatch for image_encoder.trunk.patch_embed.proj.bias: copying a param with shape torch.Size([96]) from checkpoint, the shape in current model is torch.Size([144]).

.......
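The key names in this error can be decoded. As far as I can tell from SAM2's Hiera implementation, only the first block of each stage after the first changes the channel width, so only those blocks own a `trunk.blocks.<i>.proj` layer. The checkpoint's `proj` blocks (1, 3, 10) and its 96-channel `pos_embed` fit the tiny layout, while the model's (2, 8, ...) and 144 channels fit Hiera-L. The missing `blocks.12.*` keys fit too, since tiny has only blocks 0 to 11. Together this suggests the model is still being built from the large config. The sketch below checks the block indices, assuming the stage layouts listed (recalled from the SAM2 v2.0 YAMLs, not stated in this issue); `STAGES`, `proj_blocks`, `proj_indices_from_keys` and `guess_variant` are hypothetical helpers.

```python
import re
from itertools import accumulate

# Blocks per Hiera stage, recalled from the SAM2 v2.0 configs (an assumption;
# check the YAML files you actually use).
STAGES = {
    "tiny": (1, 2, 7, 2),
    "small": (1, 2, 11, 2),
    "base_plus": (2, 3, 16, 3),
    "large": (2, 6, 36, 4),
}

# Matches e.g. "image_encoder.trunk.blocks.10.proj.weight",
# but not "image_encoder.trunk.blocks.12.attn.proj.weight".
PROJ_KEY = re.compile(r"trunk\.blocks\.(\d+)\.proj\.")


def proj_blocks(stages):
    """Indices of blocks that open a new stage (all but the first stage)."""
    return list(accumulate(stages))[:-1]


def proj_indices_from_keys(keys):
    """Collect the block indices of all `trunk.blocks.<i>.proj.*` keys."""
    found = set()
    for key in keys:
        match = PROJ_KEY.search(key)
        if match:
            found.add(int(match.group(1)))
    return sorted(found)


def guess_variant(proj_indices):
    """Names of the layouts whose stage-opening blocks match exactly."""
    target = sorted(proj_indices)
    return [name for name, stages in STAGES.items() if proj_blocks(stages) == target]


# Keys reported as "unexpected", i.e. present in the checkpoint file.
unexpected = [
    "image_encoder.trunk.blocks.1.proj.weight",
    "image_encoder.trunk.blocks.1.proj.bias",
    "image_encoder.trunk.blocks.3.proj.weight",
    "image_encoder.trunk.blocks.3.proj.bias",
    "image_encoder.trunk.blocks.10.proj.weight",
    "image_encoder.trunk.blocks.10.proj.bias",
]

if __name__ == "__main__":
    print(proj_indices_from_keys(unexpected))                 # [1, 3, 10]
    print(guess_variant(proj_indices_from_keys(unexpected)))  # ['tiny']
    print(proj_blocks(STAGES["large"]))                       # [2, 8, 44]
```

If the model side still comes out as Hiera-L after the YAML edit, it is worth confirming which config file actually reaches the model builder (in SAM2, `build_sam2`).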

FeiYull, Sep 27 '24 09:09