
subm with same indice_key must have same kernel size

Open runzhangDL opened this issue 2 years ago • 23 comments

Traceback (most recent call last):
  File "train_cylinder_asym.py", line 167, in <module>
    main(args)
  File "train_cylinder_asym.py", line 132, in main
    outputs = my_model(train_pt_fea_ten, train_vox_ten, train_batch_size)
  File "/home/autowise/data_ssd/zhangrun/Cylinder3D/conda_envs/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/autowise/data_ssd/zhangrun/Cylinder3D/network/cylinder_spconv_3d.py", line 44, in forward
    spatial_features = self.cylinder_3d_spconv_seg(features_3d, coords, batch_size)
  File "/home/autowise/data_ssd/zhangrun/Cylinder3D/conda_envs/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/autowise/data_ssd/zhangrun/Cylinder3D/network/segmentator_3d_asymm_spconv.py", line 290, in forward
    ret = self.downCntx(ret)
  File "/home/autowise/data_ssd/zhangrun/Cylinder3D/conda_envs/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/autowise/data_ssd/zhangrun/Cylinder3D/network/segmentator_3d_asymm_spconv.py", line 78, in forward
    shortcut = self.conv1_2(shortcut)
  File "/home/autowise/data_ssd/zhangrun/Cylinder3D/conda_envs/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/autowise/data_ssd/zhangrun/Cylinder3D/conda_envs/lib/python3.7/site-packages/spconv/pytorch/conv.py", line 374, in forward
    datas)
  File "/home/autowise/data_ssd/zhangrun/Cylinder3D/conda_envs/lib/python3.7/site-packages/spconv/pytorch/conv.py", line 470, in _check_subm_reuse_valid
    f"subm with same indice_key must have same kernel"
ValueError: subm with same indice_key must have same kernel size, expect [1, 3, 3], this layer [3, 1, 3]

runzhangDL avatar Dec 16 '21 11:12 runzhangDL

By the way, could anyone help me understand what each of these parameters (from semantickitti.yaml) means?

model_params:
  model_architecture: "cylinder_asym"

  output_shape:
    - 480
    - 360
    - 32

  fea_dim: 9
  out_fea_dim: 256
  num_class: 20
  num_input_features: 16
  use_norm: True
  init_size: 32

runzhangDL avatar Dec 16 '21 11:12 runzhangDL

I have the same problem, how did you solve that?

singal95 avatar Dec 23 '21 10:12 singal95

Same issue here, although I use spconv v2. I permuted the weights of a pretrained state dictionary to fit the new spconv v2 format, using PyTorch code along the lines of:

def permute_pretrained_weights(model):
    # Permute the kernel axes of matching conv weights to fit the spconv v2 layout.
    my_model_dict = model.state_dict()
    for key in my_model_dict:
        if "<INSERT KEYWORD>" in str(key):  # filter for the conv layers to permute
            change_tensor = my_model_dict[key].clone()
            my_model_dict[key] = change_tensor.permute(0, 1, 3, 2, 4)
    model.load_state_dict(my_model_dict)
    return model

@runorz @singal95 are you using spconv v2 or v1? The issue might be that Cylinder3D (unfortunately) uses the old, deprecated version. This fork by mtzhang1999 helped me.

L-Reichardt avatar Jan 13 '22 15:01 L-Reichardt

Hi @L-Reichardt, by making the changes from mtzhang1999 and using your permutation order, were you able to load the checkpoints from the author here and get the expected inference results? I am trying the changes from mtzhang1999 and am left with some size mismatches, like the following, that cannot be solved by permutation:

No match: cylinder_3d_spconv_seg.upBlock1.conv3.weight
    model:      torch.Size([256, 1, 3, 3, 256])
    checkpoint: torch.Size([3, 3, 3, 256, 256])

Did you make any additional modifications to the mtzhang1999 code? Thank you!

min2209 avatar Jan 28 '22 05:01 min2209

@min2209 Hello. No, I didn't make additional changes. In combination, both worked for me: I could load the pretrained weights without an error. However, the results showed that something was still wrong; qualitatively I had bad performance.

At that point I stopped trying to make the weights fit. I trained for 2 epochs on top of the pretrained weights and the results were normal again. However, I plan on retraining completely at some point.

L-Reichardt avatar Jan 29 '22 13:01 L-Reichardt

@L-Reichardt Thank you for your reply. In terms of the changes made by mtzhang1999, did you copy the changes at https://github.com/xinge008/Cylinder3D/pull/113/commits/df586da89e6275b21370d46619409c42340e1401 exactly? There is one strange (in my opinion) change on line 181 in the UpBlock:

# self.conv3 = conv3x3(out_filters, out_filters, indice_key=indice_key)
self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key)

Whereas all the other changes introduced by mtzhang1999 swap conv1x3 and conv3x1, this one explicitly changes the shape of the convolution in a way that no permutation can undo.

By getting results to be normal, did you mean that you achieved ~67% mIoU on the validation set?

Thank you!

min2209 avatar Jan 29 '22 18:01 min2209

@min2209 Yeah, you are right, that change seems "inconsistent". Have you tried loading the weights with the 3x3 conv? By normal results I meant that they were qualitatively normal. I reached 62%. Currently retraining completely; I'll let you know in a few days what the result is (if the PC doesn't freeze again).

L-Reichardt avatar Jan 30 '22 00:01 L-Reichardt

@L-Reichardt - I tried loading with 3x3 and, as I mentioned above, it doesn't load even with permutations: a 3x3 conv has 3x3x3 = 27 weights per channel pair, which no permutation or reshape can map onto the 9 weights of a 1x3 or 3x1 kernel. Currently I have the following:

  • Using mtzhang1999's patch, I reached 62%, like your experiment.
  • I have read through the paper and the spconv1 -> spconv2 upgrade in greater detail, and discovered that we don't in fact need to change the convolution kernel dimension order. The spconv1 -> spconv2 upgrade should only move the input channel, output channel, and kernel size dimensions around. Instead, I have reverted to @xinge008 's original code. However, it appears that spconv2 imposes a strict requirement on the naming of the kernels' indice_keys: if a 3x1 and a 1x3 kernel share the same indice_key, it causes problems. Therefore, within each block, I added a unique number to each indice_key, and the model runs successfully without any conv changes. With this change, I was able to load the released weights with the permutation .permute(4, 0, 1, 2, 3), which agrees with the spconv2 breaking-changes note (a loading sketch follows below). UNFORTUNATELY, this inference result is still wrong, so I am re-training again with this. The first epoch reached 54% so far.
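
A minimal sketch of that loading step (illustrative names; not the repository's loader, and the checkpoint path is a placeholder):

import torch

def load_spconv1_checkpoint(model, path):
    # Permute 5-D sparse-conv weights from spconv1's layout to spconv2's
    # (the kernel dims stay in order; only the channel dims move), and copy
    # everything else through when the shapes already match.
    model_dict = model.state_dict()
    checkpoint = torch.load(path, map_location="cpu")
    for key, value in checkpoint.items():
        if key not in model_dict:
            continue
        if value.dim() == 5 and model_dict[key].shape != value.shape:
            value = value.permute(4, 0, 1, 2, 3)
        if model_dict[key].shape == value.shape:
            model_dict[key] = value
    model.load_state_dict(model_dict)
    return model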

For your reference, the modified version I'm using now is as follows here:

# -*- coding:utf-8 -*-
# author: Xinge
# @file: segmentator_3d_asymm_spconv.py

import numpy as np
#import spconv
import spconv.pytorch as spconv
import torch
from torch import nn


def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                             padding=1, bias=False, indice_key=indice_key)


def conv1x3(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,
                             padding=(0, 1, 1), bias=False, indice_key=indice_key)


def conv1x1x3(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride,
                             padding=(0, 0, 1), bias=False, indice_key=indice_key)


def conv1x3x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride,
                             padding=(0, 1, 0), bias=False, indice_key=indice_key)


def conv3x1x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride,
                             padding=(1, 0, 0), bias=False, indice_key=indice_key)


def conv3x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride,
                             padding=(1, 0, 1), bias=False, indice_key=indice_key)


def conv1x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=1, stride=stride,
                             padding=1, bias=False, indice_key=indice_key)


class ResContextBlock(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None):
        super(ResContextBlock, self).__init__()
        self.conv1 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef1")
        self.bn0 = nn.BatchNorm1d(out_filters)
        self.act1 = nn.LeakyReLU()
          
        self.conv1_2 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef2")
        # self.conv1_2 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef")

        self.bn0_2 = nn.BatchNorm1d(out_filters)
        self.act1_2 = nn.LeakyReLU()

        self.conv2 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef3")
        self.act2 = nn.LeakyReLU()
        self.bn1 = nn.BatchNorm1d(out_filters)

        self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef4")
        # self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef")
        self.act3 = nn.LeakyReLU()
        self.bn2 = nn.BatchNorm1d(out_filters)

        self.weight_initialization()

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = shortcut.replace_feature(self.act1(shortcut.features))
        shortcut = shortcut.replace_feature(self.bn0(shortcut.features))

        shortcut = self.conv1_2(shortcut)
        shortcut = shortcut.replace_feature(self.act1_2(shortcut.features))
        shortcut = shortcut.replace_feature(self.bn0_2(shortcut.features))

        resA = self.conv2(x)
        resA = resA.replace_feature(self.act2(resA.features))
        reaA = resA.replace_feature(self.bn1(resA.features))

        resA = self.conv3(resA)
        resA = resA.replace_feature(self.act3(resA.features))
        resA = resA.replace_feature(self.bn2(resA.features))
        resA = resA.replace_feature(resA.features + shortcut.features)

        return resA


class ResBlock(nn.Module):
    def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), stride=1,
                 pooling=True, drop_out=True, height_pooling=False, indice_key=None):
        super(ResBlock, self).__init__()
        self.pooling = pooling
        self.drop_out = drop_out

        self.conv1 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef1")
        self.act1 = nn.LeakyReLU()
        self.bn0 = nn.BatchNorm1d(out_filters)

        self.conv1_2 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef2")
        # self.conv1_2 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef")
        self.act1_2 = nn.LeakyReLU()
        self.bn0_2 = nn.BatchNorm1d(out_filters)

        self.conv2 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef3")
        self.act2 = nn.LeakyReLU()
        self.bn1 = nn.BatchNorm1d(out_filters)

        self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef4")
        # self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef")
        self.act3 = nn.LeakyReLU()
        self.bn2 = nn.BatchNorm1d(out_filters)

        if pooling:
            if height_pooling:
                self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=2,
                                                padding=1, indice_key=indice_key, bias=False)
            else:
                self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),
                                                padding=1, indice_key=indice_key, bias=False)
        self.weight_initialization()

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = shortcut.replace_feature(self.act1(shortcut.features))
        shortcut = shortcut.replace_feature(self.bn0(shortcut.features))

        shortcut = self.conv1_2(shortcut)
        shortcut = shortcut.replace_feature(self.act1_2(shortcut.features))
        shortcut = shortcut.replace_feature(self.bn0_2(shortcut.features))

        resA = self.conv2(x)
        resA = resA.replace_feature(self.act2(resA.features))
        resA = resA.replace_feature(self.bn1(resA.features))

        resA = self.conv3(resA)
        resA = resA.replace_feature(self.act3(resA.features))
        resA = resA.replace_feature(self.bn2(resA.features))

        resA = resA.replace_feature(resA.features + shortcut.features)

        if self.pooling:
            resB = self.pool(resA)
            return resB, resA
        else:
            return resA


class UpBlock(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), indice_key=None, up_key=None):
        super(UpBlock, self).__init__()
        # self.drop_out = drop_out
        self.trans_dilao = conv3x3(in_filters, out_filters, indice_key=indice_key + "new_up")
        self.trans_act = nn.LeakyReLU()
        self.trans_bn = nn.BatchNorm1d(out_filters)

        self.conv1 = conv1x3(out_filters, out_filters, indice_key=indice_key+'up1')
        self.act1 = nn.LeakyReLU()
        self.bn1 = nn.BatchNorm1d(out_filters)

        self.conv2 = conv3x1(out_filters, out_filters, indice_key=indice_key+'up2')
        # self.conv2 = conv1x3(out_filters, out_filters, indice_key=indice_key)
        self.act2 = nn.LeakyReLU()
        self.bn2 = nn.BatchNorm1d(out_filters)

        self.conv3 = conv3x3(out_filters, out_filters, indice_key=indice_key+'up3')
        self.act3 = nn.LeakyReLU()
        self.bn3 = nn.BatchNorm1d(out_filters)
        # self.dropout3 = nn.Dropout3d(p=dropout_rate)

        self.up_subm = spconv.SparseInverseConv3d(out_filters, out_filters, kernel_size=3, indice_key=up_key,
                                                  bias=False)

        self.weight_initialization()

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, skip):
        upA = self.trans_dilao(x)
        upA = upA.replace_feature(self.trans_act(upA.features))
        upA = upA.replace_feature(self.trans_bn(upA.features))

        ## upsample
        upA = self.up_subm(upA)

        upA = upA.replace_feature(upA.features + skip.features)

        upE = self.conv1(upA)
        upE = upE.replace_feature(self.act1(upE.features))
        upE = upE.replace_feature(self.bn1(upE.features))

        upE = self.conv2(upE)
        upE = upE.replace_feature(self.act2(upE.features))
        upE = upE.replace_feature(self.bn2(upE.features))

        upE = self.conv3(upE)
        upE = upE.replace_feature(self.act3(upE.features))
        upE = upE.replace_feature(self.bn3(upE.features))

        return upE


class ReconBlock(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None):
        super(ReconBlock, self).__init__()
        self.conv1 = conv3x1x1(in_filters, out_filters, indice_key=indice_key + "bef1")
        self.bn0 = nn.BatchNorm1d(out_filters)
        self.act1 = nn.Sigmoid()

        self.conv1_2 = conv1x3x1(in_filters, out_filters, indice_key=indice_key + "bef2")
        self.bn0_2 = nn.BatchNorm1d(out_filters)
        self.act1_2 = nn.Sigmoid()

        self.conv1_3 = conv1x1x3(in_filters, out_filters, indice_key=indice_key + "bef3")
        self.bn0_3 = nn.BatchNorm1d(out_filters)
        self.act1_3 = nn.Sigmoid()

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = shortcut.replace_feature(self.bn0(shortcut.features))
        shortcut = shortcut.replace_feature(self.act1(shortcut.features))

        shortcut2 = self.conv1_2(x)
        shortcut2 = shortcut2.replace_feature(self.bn0_2(shortcut2.features))
        shortcut2 = shortcut2.replace_feature(self.act1_2(shortcut2.features))

        shortcut3 = self.conv1_3(x)
        shortcut3 = shortcut3.replace_feature(self.bn0_3(shortcut3.features))
        shortcut3 = shortcut3.replace_feature(self.act1_3(shortcut3.features))
        shortcut = shortcut.replace_feature(shortcut.features + shortcut2.features + shortcut3.features)

        shortcut = shortcut.replace_feature(shortcut.features * x.features)

        return shortcut


class Asymm_3d_spconv(nn.Module):
    def __init__(self,
                 output_shape,
                 use_norm=True,
                 num_input_features=128,
                 nclasses=20, n_height=32, strict=False, init_size=16):
        super(Asymm_3d_spconv, self).__init__()
        self.nclasses = nclasses
        self.nheight = n_height
        self.strict = False

        sparse_shape = np.array(output_shape)
        # sparse_shape[0] = 11
        print(sparse_shape)
        self.sparse_shape = sparse_shape

        self.downCntx = ResContextBlock(num_input_features, init_size, indice_key="pre")
        self.resBlock2 = ResBlock(init_size, 2 * init_size, 0.2, height_pooling=True, indice_key="down2")
        self.resBlock3 = ResBlock(2 * init_size, 4 * init_size, 0.2, height_pooling=True, indice_key="down3")
        self.resBlock4 = ResBlock(4 * init_size, 8 * init_size, 0.2, pooling=True, height_pooling=False,
                                  indice_key="down4")
        self.resBlock5 = ResBlock(8 * init_size, 16 * init_size, 0.2, pooling=True, height_pooling=False,
                                  indice_key="down5")

        self.upBlock0 = UpBlock(16 * init_size, 16 * init_size, indice_key="up0", up_key="down5")
        self.upBlock1 = UpBlock(16 * init_size, 8 * init_size, indice_key="up1", up_key="down4")
        self.upBlock2 = UpBlock(8 * init_size, 4 * init_size, indice_key="up2", up_key="down3")
        self.upBlock3 = UpBlock(4 * init_size, 2 * init_size, indice_key="up3", up_key="down2")

        self.ReconNet = ReconBlock(2 * init_size, 2 * init_size, indice_key="recon")

        self.logits = spconv.SubMConv3d(4 * init_size, nclasses, indice_key="logit", kernel_size=3, stride=1, padding=1,
                                        bias=True)

    def forward(self, voxel_features, coors, batch_size):
        # x = x.contiguous()
        coors = coors.int()
        # import pdb
        # pdb.set_trace()
        ret = spconv.SparseConvTensor(voxel_features, coors, self.sparse_shape,
                                      batch_size)
        ret = self.downCntx(ret)
        down1c, down1b = self.resBlock2(ret)
        down2c, down2b = self.resBlock3(down1c)
        down3c, down3b = self.resBlock4(down2c)
        down4c, down4b = self.resBlock5(down3c)

        up4e = self.upBlock0(down4c, down4b)
        up3e = self.upBlock1(up4e, down3b)
        up2e = self.upBlock2(up3e, down2b)
        up1e = self.upBlock3(up2e, down1b)

        up0e = self.ReconNet(up1e)

        up0e = up0e.replace_feature(torch.cat((up0e.features, up1e.features), 1))

        logits = self.logits(up0e)
        y = logits.dense()
        return y

min2209 avatar Jan 30 '22 01:01 min2209

Excuse me, when I test the pretrained model on the KITTI 08 sequence, the results are all 0. Have you encountered this problem before?

GaloisWang avatar Apr 01 '22 06:04 GaloisWang

I have tried the code by min2209 but only reached 63.975 mIoU using the default parameters. Was anyone else able to reach higher performance?

Montyro avatar Apr 05 '22 09:04 Montyro

I have tried the code by min2209 but only reached 63.975 mIoU using the default parameters. Was anyone else able to reach higher performance?

Same here. The best result I have gotten is 63.698.

GaloisWang avatar Apr 06 '22 02:04 GaloisWang

Yes, actually I was also unable to fully reproduce the authors' claimed numbers. I also got to around 64 mIoU.

@GaloisWang - using the pre-trained model on the KITTI 08 sequence, I also get all 0's. Then again, I was never able to load the model into the authors' code directly anyway.

I think one possibility is to actually instantiate one of these 1x3 and 3x1 blocks manually (e.g. single channel) and run it on a pre-determined binary input map. It should be easy to check whether the results are correct, and whether the shape is correct, according to the spconv1 -> spconv2 conversion policy in my code snippet above.
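
A minimal sketch of such a probe (assumed shapes and names, not code from the repo; it needs a CUDA build of spconv):

import torch
import spconv.pytorch as spconv

device = torch.device("cuda")
# Two active voxels in a 4x4x4 grid; index columns are (batch, z, y, x).
coords = torch.tensor([[0, 1, 1, 1], [0, 1, 1, 2]], dtype=torch.int32, device=device)
feats = torch.ones((2, 1), device=device)
x = spconv.SparseConvTensor(feats, coords, spatial_shape=[4, 4, 4], batch_size=1)

# Single-channel 1x3x3 submanifold conv, analogous to conv1x3 in the model.
conv = spconv.SubMConv3d(1, 1, kernel_size=(1, 3, 3), padding=(0, 1, 1),
                         bias=False, indice_key="probe").to(device)
with torch.no_grad():
    out = conv(x)

# Compare out.dense() against a dense torch.nn.Conv3d holding the same
# weights (permuted per the spconv1 -> spconv2 policy) to check both the
# output values and the expected weight shape.
print(out.dense().shape)  # torch.Size([1, 1, 4, 4, 4])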

min2209 avatar Apr 06 '22 03:04 min2209

I've uploaded my weights in a fork here. @min2209 For your information, I've used your proposed code from this issue in the fork and credited it. I hope this is OK with you.

L-Reichardt avatar Jun 13 '22 07:06 L-Reichardt

I think there is a typo in this line:

reaA = resA.replace_feature(self.bn1(resA.features))

should be:

resA = resA.replace_feature(self.bn1(resA.features))

shayannikoohemat avatar Jul 12 '22 11:07 shayannikoohemat

@shayannikoohemat Thank you

L-Reichardt avatar Jul 18 '22 06:07 L-Reichardt

Thank you!! I was trapped by this problem for a long time. I tested and studied the code on Colab and was tormented by many issues (file paths and incompatible versions across codebases)... QAQ. With your improvements I can now run the whole program normally. Thank you very much!

swimmerQAQ avatar Aug 11 '22 05:08 swimmerQAQ

@min2209 @GaloisWang @Montyro

Recently I have had the time to look over C3D again. This official repository does not contain the pointwise refinement module from the CVPR paper (ref Issue), but is actually the repository for the previous paper. In that paper an mIoU of 64.3 was reached (instead of 65.9 with the pointwise refinement module).

Considering there is a typo in this implementation (thanks @shayannikoohemat) and that the authors did not use manual seeding or weighted cross-entropy, our results are plausible, and this likely explains the difference. When I find time, I will retrain the model without the typo (and maybe extend it with the pointwise refinement module).

L-Reichardt avatar Sep 30 '22 12:09 L-Reichardt

Updated weights (without typo) and mixed precision support now at this fork

L-Reichardt avatar Oct 04 '22 11:10 L-Reichardt

Excuse me, when I test the pretrained model on the KITTI 08 sequence, the results are all 0. Have you encountered this problem before?

I met the same problem, how did you solve it?

anran1231 avatar Nov 07 '22 08:11 anran1231

Updated weights (without typo) and mixed precision support now at this fork

I got this error while using your code: [screenshot]

Body123 avatar Mar 04 '23 19:03 Body123

@Body123 The error is caused by the versions of CUDA / SpConv. The code should run regardless.

L-Reichardt avatar Mar 05 '23 06:03 L-Reichardt

@Body123 The error is caused by the versions of CUDA / SpConv. The code should run regardless.

No, I checked all the required versions and the error still happens. Can you provide a requirements file?

Body123 avatar Mar 05 '23 19:03 Body123

Hi everyone!

I successfully modified the code to the spconv 2.3.6!

I can obtain the same computation results in every network layer as the spconv 1.x code provided in this repo. My code evaluates the provided pretrained model on the validation set at 66.91 mIoU, which matches the performance reported in this repo. I modified two files: utils/load_save_util.py and network/segmentator_3d_asymm_spconv.py.

The difference between the code for spconv 1.x and my code lies in three aspects:

  • In spconv 1.x, the weights are saved as RSCK, where RS is the kernel size, C is the input channel, and K is the output channel, while in spconv 2.x the default weight layout is KRSC. See link for reference. I use .permute(4, 0, 1, 2, 3) to change the layout.

  • There is an intrinsic bug in the model architecture of Cylinder3D. In network/segmentator_3d_asymm_spconv.py, conv3 in UpBlock is designed to have a weight of size $3\times 3\times 3 = 27$. However, this convolution layer reuses the indices computed by conv1, whose weight size is $1\times 3\times 3 = 9$. This bug leaves the weight values in conv3.weights[0, :, :] valid, while conv3.weights[1:, :, :] are nan. To fix this problem, I redefine conv3 as conv1x3(out_filters, out_filters, indice_key=indice_key) in network/segmentator_3d_asymm_spconv.py. In the load_checkpoint function in utils/load_save_util.py, I only load the valid weights of conv3.

  • Many network layers use the same indice key even though their kernel sizes differ. If the spconv 1.x code is adopted directly in spconv 2.x, the error "subm with same indice_key must have same kernel size" occurs, because the convolution checks the kernel size recorded in the input data's <input_data>.indice_dict[<given_indice_key>].ksize. I manually change that recorded kernel size to match the kernel size of the convolution that follows.

My modified code for load_checkpoint function in utils/load_save_util.py:

def load_checkpoint(model_load_path, model, device=None):
    my_model_dict = model.state_dict()
    if device is not None:
        pre_weight = torch.load(model_load_path, map_location='cuda:'+str(device))
    else:
        pre_weight = torch.load(model_load_path)
    part_load = {}
    match_size = 0
    nomatch_size = 0
    for k in pre_weight.keys():
        value = pre_weight[k]
        if k in my_model_dict and my_model_dict[k].shape == value.shape:
            # print("loading ", k)
            match_size += 1
            part_load[k] = value
        elif k in my_model_dict and my_model_dict[k].shape == value.permute(4,0,1,2,3).shape:
            match_size += 1
            part_load[k] = value.permute(4,0,1,2,3)
        elif k in my_model_dict and k.split('.')[-2] == 'conv3':
            match_size += 1
            part_load[k] = value[0].unsqueeze(0).permute(4,0,1,2,3)
        else:
            nomatch_size += 1

    print("matched parameter sets: {}, and no matched: {}".format(match_size, nomatch_size))

    my_model_dict.update(part_load)
    model.load_state_dict(my_model_dict)

    return model

My modified code in network/segmentator_3d_asymm_spconv.py:

import numpy as np
import random
# import spconv
import spconv.pytorch as spconv
# import spconv.functional as Fsp
from spconv.pytorch import functional as Fsp
# from spconv import ops
from spconv.pytorch import ops
import torch
from torch import nn


def conv3x3(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                             padding=1, bias=False, indice_key=indice_key)


def conv1x3(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 3), stride=stride,
                             padding=(0, 1, 1), bias=False, indice_key=indice_key)


def conv1x1x3(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 1, 3), stride=stride,
                             padding=(0, 0, 1), bias=False, indice_key=indice_key)


def conv1x3x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(1, 3, 1), stride=stride,
                             padding=(0, 1, 0), bias=False, indice_key=indice_key)


def conv3x1x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 1), stride=stride,
                             padding=(1, 0, 0), bias=False, indice_key=indice_key)


def conv3x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=(3, 1, 3), stride=stride,
                             padding=(1, 0, 1), bias=False, indice_key=indice_key)


def conv1x1(in_planes, out_planes, stride=1, indice_key=None):
    return spconv.SubMConv3d(in_planes, out_planes, kernel_size=1, stride=stride,
                             padding=1, bias=False, indice_key=indice_key)


class ResContextBlock(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None):
        super(ResContextBlock, self).__init__()
        self.indice_key = indice_key
        self.conv1 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef")
        self.bn0 = nn.BatchNorm1d(out_filters)
        self.act1 = nn.LeakyReLU()

        self.conv1_2 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef")
        self.bn0_2 = nn.BatchNorm1d(out_filters)
        self.act1_2 = nn.LeakyReLU()

        self.conv2 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef")
        self.act2 = nn.LeakyReLU()
        self.bn1 = nn.BatchNorm1d(out_filters)

        self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef")
        self.act3 = nn.LeakyReLU()
        self.bn2 = nn.BatchNorm1d(out_filters)

        self.weight_initialization()

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = shortcut.replace_feature(self.act1(shortcut.features))
        shortcut = shortcut.replace_feature(self.bn0(shortcut.features))
        
        shortcut.indice_dict[self.indice_key + "bef"].ksize = [3,1,3]
        shortcut = self.conv1_2(shortcut)
        shortcut = shortcut.replace_feature(self.act1_2(shortcut.features))
        shortcut = shortcut.replace_feature(self.bn0_2(shortcut.features))
        
        x.indice_dict[self.indice_key + "bef"] = shortcut.indice_dict[self.indice_key + "bef"]
        resA = self.conv2(x)
        resA = resA.replace_feature(self.act2(resA.features))
        resA = resA.replace_feature(self.bn1(resA.features))
        
        resA.indice_dict[self.indice_key + "bef"].ksize = [1,3,3]
        resA = self.conv3(resA)
        resA = resA.replace_feature(self.act3(resA.features))
        resA = resA.replace_feature(self.bn2(resA.features))
        resA = resA.replace_feature(resA.features + shortcut.features)

        return resA


class ResBlock(nn.Module):
    def __init__(self, in_filters, out_filters, dropout_rate, kernel_size=(3, 3, 3), stride=1,
                 pooling=True, drop_out=True, height_pooling=False, indice_key=None):
        super(ResBlock, self).__init__()
        self.pooling = pooling
        self.drop_out = drop_out
        self.height_pooling = height_pooling
        self.indice_key = indice_key

        self.conv1 = conv3x1(in_filters, out_filters, indice_key=indice_key + "bef")
        self.act1 = nn.LeakyReLU()
        self.bn0 = nn.BatchNorm1d(out_filters)

        self.conv1_2 = conv1x3(out_filters, out_filters, indice_key=indice_key + "bef")
        self.act1_2 = nn.LeakyReLU()
        self.bn0_2 = nn.BatchNorm1d(out_filters)

        self.conv2 = conv1x3(in_filters, out_filters, indice_key=indice_key + "bef")
        self.act2 = nn.LeakyReLU()
        self.bn1 = nn.BatchNorm1d(out_filters)

        self.conv3 = conv3x1(out_filters, out_filters, indice_key=indice_key + "bef")
        self.act3 = nn.LeakyReLU()
        self.bn2 = nn.BatchNorm1d(out_filters)

        if pooling:
            if height_pooling:
                self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=2,
                                                padding=1, indice_key=indice_key, bias=False)
            else:
                self.pool = spconv.SparseConv3d(out_filters, out_filters, kernel_size=3, stride=(2, 2, 1),
                                                padding=1, indice_key=indice_key, bias=False)
        self.weight_initialization()

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = shortcut.replace_feature(self.act1(shortcut.features))
        shortcut = shortcut.replace_feature(self.bn0(shortcut.features))
        
        shortcut.indice_dict[self.indice_key + "bef"].ksize = [1,3,3]
        shortcut = self.conv1_2(shortcut)
        shortcut = shortcut.replace_feature(self.act1_2(shortcut.features))
        shortcut = shortcut.replace_feature(self.bn0_2(shortcut.features))
        
        x.indice_dict[self.indice_key + "bef"] = shortcut.indice_dict[self.indice_key + "bef"]
        resA = self.conv2(x)
        resA = resA.replace_feature(self.act2(resA.features))
        resA = resA.replace_feature(self.bn1(resA.features))
        
        resA.indice_dict[self.indice_key + "bef"].ksize = [3,1,3]
        resA = self.conv3(resA)
        resA = resA.replace_feature(self.act3(resA.features))
        resA = resA.replace_feature(self.bn2(resA.features))
        resA = resA.replace_feature(resA.features + shortcut.features)

        if self.pooling:
            resB = self.pool(resA)
            return resB, resA
        else:
            return resA


class UpBlock(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), indice_key=None, up_key=None):
        super(UpBlock, self).__init__()
        # self.drop_out = drop_out
        self.indice_key = indice_key
        self.trans_dilao = conv3x3(in_filters, out_filters, indice_key=indice_key + "new_up")
        self.trans_act = nn.LeakyReLU()
        self.trans_bn = nn.BatchNorm1d(out_filters)

        self.conv1 = conv1x3(out_filters, out_filters, indice_key=indice_key)
        self.act1 = nn.LeakyReLU()
        self.bn1 = nn.BatchNorm1d(out_filters)

        self.conv2 = conv3x1(out_filters, out_filters, indice_key=indice_key)
        self.act2 = nn.LeakyReLU()
        self.bn2 = nn.BatchNorm1d(out_filters)

        # self.conv3 = conv3x3(out_filters, out_filters, indice_key=indice_key)
        self.conv3 = conv1x3(out_filters, out_filters, indice_key=indice_key)
        self.act3 = nn.LeakyReLU()
        self.bn3 = nn.BatchNorm1d(out_filters)
        # self.dropout3 = nn.Dropout3d(p=dropout_rate)

        self.up_subm = spconv.SparseInverseConv3d(out_filters, out_filters, kernel_size=3, indice_key=up_key,
                                                  bias=False)

        self.weight_initialization()

    def weight_initialization(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, skip):
        upA = self.trans_dilao(x)
        upA = upA.replace_feature(self.trans_act(upA.features))
        upA = upA.replace_feature(self.trans_bn(upA.features))
        
        upA = self.up_subm(upA)
        upA = upA.replace_feature(upA.features + skip.features)

        upE = self.conv1(upA)
        upE = upE.replace_feature(self.act1(upE.features))
        upE = upE.replace_feature(self.bn1(upE.features))
        
        upE.indice_dict[self.indice_key].ksize = [3,1,3]
        upE = self.conv2(upE)
        upE = upE.replace_feature(self.act2(upE.features))
        upE = upE.replace_feature(self.bn2(upE.features))
        
        upE.indice_dict[self.indice_key].ksize = [1,3,3]
        upE = self.conv3(upE)
        upE = upE.replace_feature(self.act3(upE.features))
        upE = upE.replace_feature(self.bn3(upE.features))
        
        return upE


class ReconBlock(nn.Module):
    def __init__(self, in_filters, out_filters, kernel_size=(3, 3, 3), stride=1, indice_key=None):
        super(ReconBlock, self).__init__()
        self.indice_key = indice_key
        self.conv1 = conv3x1x1(in_filters, out_filters, indice_key=indice_key + "bef")
        self.bn0 = nn.BatchNorm1d(out_filters)
        self.act1 = nn.Sigmoid()

        self.conv1_2 = conv1x3x1(in_filters, out_filters, indice_key=indice_key + "bef")
        self.bn0_2 = nn.BatchNorm1d(out_filters)
        self.act1_2 = nn.Sigmoid()

        self.conv1_3 = conv1x1x3(in_filters, out_filters, indice_key=indice_key + "bef")
        self.bn0_3 = nn.BatchNorm1d(out_filters)
        self.act1_3 = nn.Sigmoid()

    def forward(self, x):
        shortcut = self.conv1(x)
        shortcut = shortcut.replace_feature(self.bn0(shortcut.features))
        shortcut = shortcut.replace_feature(self.act1(shortcut.features))

        x.indice_dict[self.indice_key + "bef"] = shortcut.indice_dict[self.indice_key + "bef"]
        x.indice_dict[self.indice_key + "bef"].ksize = [1,3,1]
        shortcut2 = self.conv1_2(x)
        shortcut2 = shortcut2.replace_feature(self.bn0_2(shortcut2.features))
        shortcut2 = shortcut2.replace_feature(self.act1_2(shortcut2.features))

        x.indice_dict[self.indice_key + "bef"] = shortcut.indice_dict[self.indice_key + "bef"]
        x.indice_dict[self.indice_key + "bef"].ksize = [1,1,3]
        shortcut3 = self.conv1_3(x)
        shortcut3 = shortcut3.replace_feature(self.bn0_3(shortcut3.features))
        shortcut3 = shortcut3.replace_feature(self.act1_3(shortcut3.features))
        shortcut = shortcut3.replace_feature(shortcut.features + shortcut2.features + shortcut3.features)

        shortcut = shortcut.replace_feature(shortcut.features * x.features)

        return shortcut


class Asymm_3d_spconv(nn.Module):
    def __init__(self,
                 output_shape,
                 use_norm=True,
                 num_input_features=128,
                 nclasses=20, n_height=32, strict=False, init_size=16):
        super(Asymm_3d_spconv, self).__init__()
        self.nclasses = nclasses
        self.nheight = n_height
        self.strict = False

        sparse_shape = np.array(output_shape)
        self.sparse_shape = sparse_shape

        self.downCntx = ResContextBlock(num_input_features, init_size, indice_key="pre")
        self.resBlock2 = ResBlock(init_size, 2 * init_size, 0.2, height_pooling=True, indice_key="down2")
        self.resBlock3 = ResBlock(2 * init_size, 4 * init_size, 0.2, height_pooling=True, indice_key="down3")
        self.resBlock4 = ResBlock(4 * init_size, 8 * init_size, 0.2, pooling=True, height_pooling=False,
                                  indice_key="down4")
        self.resBlock5 = ResBlock(8 * init_size, 16 * init_size, 0.2, pooling=True, height_pooling=False,
                                  indice_key="down5")

        self.upBlock0 = UpBlock(16 * init_size, 16 * init_size, indice_key="up0", up_key="down5")
        self.upBlock1 = UpBlock(16 * init_size, 8 * init_size, indice_key="up1", up_key="down4")
        self.upBlock2 = UpBlock(8 * init_size, 4 * init_size, indice_key="up2", up_key="down3")
        self.upBlock3 = UpBlock(4 * init_size, 2 * init_size, indice_key="up3", up_key="down2")

        self.ReconNet = ReconBlock(2 * init_size, 2 * init_size, indice_key="recon")

        self.logits = spconv.SubMConv3d(4 * init_size, nclasses, indice_key="logit", kernel_size=3, stride=1, padding=1,
                                        bias=True)

    def forward(self, voxel_features, coors, batch_size):

        coors = coors.int()
        ret = spconv.SparseConvTensor(voxel_features, coors, self.sparse_shape,
                                      batch_size)
        ret = self.downCntx(ret)
        down1c, down1b = self.resBlock2(ret)
        down2c, down2b = self.resBlock3(down1c)
        down3c, down3b = self.resBlock4(down2c)
        down4c, down4b = self.resBlock5(down3c)

        up4e = self.upBlock0(down4c, down4b)
        up3e = self.upBlock1(up4e, down3b)
        up2e = self.upBlock2(up3e, down2b)
        up1e = self.upBlock3(up2e, down1b)
        up0e = self.ReconNet(up1e)

        up0e = up0e.replace_feature(torch.cat((up0e.features, up1e.features), 1))

        logits = self.logits(up0e)
        y = logits.dense()
        return y

antao97 avatar Jun 15 '23 03:06 antao97