onnx-coreml

Add support for BatchNorm2d (and more options for Slice)

Open • jbmaxwell opened this issue 5 years ago • 1 comment

I'm trying to convert ClusterGAN, and the conversion fails in two places. The first error comes from a Slice in the encoder's forward() function:

def forward(self, in_feat):
    z_img = self.model(in_feat)
    # Reshape for output
    z = z_img.view(z_img.shape[0], -1)
    # Separate continuous and one-hot components
    zn = z[:, 0:self.latent_dim]
    zc_logits = z[:, self.latent_dim:]
    # Softmax on zc component
    zc = softmax(zc_logits)
    return zn, zc, zc_logits

I can work around that error by moving the Slice out of the converted model and handling it in Swift. However, the second error, which I'm not sure I can work around, comes from a BatchNorm2d() layer in the generator. ClusterGAN is a great fit for my purposes, so I'd love to get this model converted.
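The Slice workaround amounts to exporting the encoder only up to the flattened tensor z and doing the split and softmax on the host. The author did this in Swift; the minimal Python sketch below shows the same per-sample logic. The function name `split_encoder_output` is hypothetical, and it assumes the softmax runs across the n_c one-hot logits of each sample (dim=1 for a batch) and that `latent_dim` is smaller than the row length.

```python
import math
from typing import List, Tuple


def split_encoder_output(z: List[float], latent_dim: int) -> Tuple[List[float], List[float], List[float]]:
    """Hypothetical helper: split one flattened encoder output row into (zn, zc, zc_logits)."""
    zn = z[:latent_dim]          # continuous part, like z[:, 0:self.latent_dim]
    zc_logits = z[latent_dim:]   # one-hot logits, like z[:, self.latent_dim:]
    # Numerically stable softmax: subtract the largest logit before exponentiating.
    m = max(zc_logits)
    exps = [math.exp(v - m) for v in zc_logits]
    total = sum(exps)
    zc = [e / total for e in exps]
    return zn, zc, zc_logits
```

For example, `split_encoder_output([0.5, -1.0, 2.0, 0.0, 0.0], latent_dim=2)` returns `zn == [0.5, -1.0]` and a `zc` whose three entries sum to 1. The same few lines translate directly to Swift.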

jbmaxwell, Mar 29 '20 00:03

It seems I was mistaken about which layer causes the error. I tried some reshaping so that the BatchNorm2d could be replaced with a BatchNorm1d, but I still get the error: Error while converting op of type: BatchNormalization. Error message: provided number axes 2 not supported.
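One possible workaround, not proposed in this thread, is batch-norm folding. In inference mode, batch norm is a per-feature affine map, y' = gamma * (y - mean) / sqrt(var + eps) + beta, so it can be merged into the preceding Linear (or ConvTranspose2d) layer and the BatchNormalization op disappears from the exported graph. The pure-Python sketch below folds one batch-norm layer into a Linear layer and checks that the outputs match. The helper names are hypothetical; in PyTorch the inputs would correspond to `bn.weight`, `bn.bias`, `bn.running_mean`, `bn.running_var` and `bn.eps`, and the model must be in `eval()` mode. Whether onnx-coreml then converts the rest of the generator is an untested assumption.

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]  # PyTorch Linear layout: (out_features, in_features)


def linear(x: Vector, weight: Matrix, bias: Vector) -> Vector:
    """y[j] = sum_i weight[j][i] * x[i] + bias[j]"""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batchnorm_eval(y: Vector, gamma: Vector, beta: Vector,
                   mean: Vector, var: Vector, eps: float = 1e-5) -> Vector:
    """Inference-mode batch norm: normalize with running statistics, then scale and shift."""
    return [g * (v - m) / math.sqrt(s + eps) + bt
            for v, g, bt, m, s in zip(y, gamma, beta, mean, var)]


def fold_bn_into_linear(weight: Matrix, bias: Vector, gamma: Vector, beta: Vector,
                        mean: Vector, var: Vector, eps: float = 1e-5) -> Tuple[Matrix, Vector]:
    """Return (weight', bias') so linear(x, weight', bias') matches batchnorm_eval(linear(x, weight, bias), ...) up to rounding."""
    new_w: Matrix = []
    new_b: Vector = []
    for row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)          # per-output-feature scale
        new_w.append([w * scale for w in row])  # scale row j of the weight matrix
        new_b.append((b - m) * scale + bt)      # fold the mean shift and beta into the bias
    return new_w, new_b
```

After copying the folded weights back into each Linear layer, the corresponding BatchNorm1d could be swapped for `nn.Identity()` before export. `eps=1e-5` matches PyTorch's default.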

The original generator model structure is:

self.model = nn.Sequential(
    # Fully connected layers
    torch.nn.Linear(self.latent_dim + self.n_c, 1024),
    nn.BatchNorm1d(1024),
    #torch.nn.ReLU(True),
    nn.LeakyReLU(0.2, inplace=True),
    torch.nn.Linear(1024, self.iels),
    nn.BatchNorm1d(self.iels),
    #torch.nn.ReLU(True),
    nn.LeakyReLU(0.2, inplace=True),

    # Reshape to 128 x (7x7)
    Reshape(self.ishape),

    # Upconvolution layers
    nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=True),
    nn.BatchNorm2d(64),
    #torch.nn.ReLU(True),
    nn.LeakyReLU(0.2, inplace=True),

    nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1, bias=True),
    nn.Sigmoid()
)
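To see which tensor shapes the BatchNorm layers act on, the sketch below applies the ConvTranspose2d output-size formula from the PyTorch documentation to the two upconvolution layers. It assumes `self.ishape == (128, 7, 7)`, as the "Reshape to 128 x (7x7)" comment suggests; the helper names are hypothetical.

```python
from typing import List, Tuple


def conv_transpose2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0,
                         output_padding: int = 0, dilation: int = 1) -> int:
    # H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


def generator_shapes(ishape: Tuple[int, int, int] = (128, 7, 7)) -> List[Tuple[int, int, int]]:
    """Hypothetical helper: (C, H, W) after the reshape and after each ConvTranspose2d."""
    c, h, w = ishape
    shapes = [ishape]
    for out_c in (64, 1):  # ConvTranspose2d(128, 64, ...) then ConvTranspose2d(64, 1, ...)
        h = conv_transpose2d_out(h, kernel=4, stride=2, padding=1)
        w = conv_transpose2d_out(w, kernel=4, stride=2, padding=1)
        c = out_c
        shapes.append((c, h, w))
    return shapes
```

`generator_shapes()` yields (128, 7, 7), then (64, 14, 14), then (1, 28, 28). So BatchNorm2d(64) normalizes each of the 64 channels of a 14x14 map. Folding it into the first ConvTranspose2d would scale along weight dimension 1, because PyTorch stores ConvTranspose2d weights as (in_channels, out_channels / groups, kH, kW).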

jbmaxwell, Mar 29 '20 16:03