
DPRNN convert problem with pnnx

Open · SherryYu33 opened this issue 11 months ago · 0 comments

I tried to convert a model called DPRNN with pnnx. The conversion reported no errors, but the ncnn output shape is inconsistent with the input shape. Moreover, even after permuting the ncnn output to the correct shape, its values still differ from torch's output.

error log

model

how to reproduce

  1. Export DPRNN with pnnx from Python:
import pnnx
import torch
import torch.nn as nn

class DPRNN(nn.Module):
    def __init__(self, input_size, width, hidden_size):
        super().__init__()
        self.intra_rnn = nn.GRU(input_size, hidden_size//2, 2, batch_first=True, bidirectional=True)
        self.intra_fc = nn.Linear(hidden_size, hidden_size)
        self.intra_ln = nn.LayerNorm([width, hidden_size], eps=1e-8)
        
        self.inter_rnn = nn.GRU(input_size, hidden_size, 2, batch_first=True)
        self.inter_fc = nn.Linear(hidden_size, hidden_size)
        self.inter_ln = nn.LayerNorm([width, hidden_size], eps=1e-8)
        
    def forward(self, x, state):
        # input shape (B,C,T,F) --> (B,T,F,C)
        skip = x.permute(0, 2, 3, 1).contiguous()
                        
        ## intra
        # input shape (B,T,F,C) --> (B*T,F,C)
        out = skip.view(1, 4, 16)
        # (B*T,F,C)
        out = self.intra_rnn(out)[0]
        out = self.intra_fc(out)
        # (B*T,F,C) --> (B,T,F,C)
        out = out.view(1, 1, 4, 16)
        out = self.intra_ln(out)
        skip = out+skip #(B,T,F,C)
        
        ## inter
        # input shape (B,T,F,C) --> (B*F,T,C)
        out = skip.transpose(1, 2).contiguous()
        out = out.view(4, 1, 16)
        # (B*F,T,C)
        out, state = self.inter_rnn(out, state)
        out = self.inter_fc(out)
        # (B*F,T,C) --> (B,T,F,C)
        out = out.view(1, 4, 1, 16)
        out = out.transpose(1, 2).contiguous()
        out = self.inter_ln(out)
        skip = out+skip #(B,T,F,C)

        # output shape (B,T,F,C) --> (B,C,T,F)
        skip = skip.permute(0, 3, 1, 2).contiguous()
        return skip, state

net = DPRNN(16, 4, 16).eval()
x = torch.rand((1, 16, 1, 4))    # (B, C, T, F)
state = torch.rand((2, 4, 16))   # GRU state: (num_layers, B*F, hidden_size)
# opt_net = pnnx.export(net, "dprnn.pt", [x, state])  # one-step alternative
mod = torch.jit.trace(net, [x, state])
mod.save("dprnn.pt")
opt_net = pnnx.convert("dprnn.pt", [x, state])
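To rule out the jit trace itself, a quick sanity check can be appended to the export script above (a minimal sketch; it reuses net, mod, x, and state from that script):

# the traced module should reproduce eager PyTorch exactly, so any
# remaining mismatch should come from the pnnx/ncnn side
with torch.no_grad():
    ref_out, ref_state = net(x, state)
    traced_out, traced_state = mod(x, state)
print((ref_out - traced_out).abs().max().item())  # expected to be ~0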
  2. Test it:
import numpy as np
import ncnn
import torch
import torch.nn as nn

class DPRNN(nn.Module):
    def __init__(self, input_size, width, hidden_size):
        super().__init__()
        self.intra_rnn = nn.GRU(input_size, hidden_size//2, 2, batch_first=True, bidirectional=True)
        self.intra_fc = nn.Linear(hidden_size, hidden_size)
        self.intra_ln = nn.LayerNorm([width, hidden_size], eps=1e-8)
        
        self.inter_rnn = nn.GRU(input_size, hidden_size, 2, batch_first=True)
        self.inter_fc = nn.Linear(hidden_size, hidden_size)
        self.inter_ln = nn.LayerNorm([width, hidden_size], eps=1e-8)
        
    def forward(self, x, state):
        # input shape (B,C,T,F) --> (B,T,F,C)
        skip = x.permute(0, 2, 3, 1).contiguous()
                        
        ## intra
        # input shape (B,T,F,C) --> (B*T,F,C)
        out = skip.view(1, 4, 16)
        # (B*T,F,C)
        out = self.intra_rnn(out)[0]
        out = self.intra_fc(out)
        # (B*T,F,C) --> (B,T,F,C)
        out = out.view(1, 1, 4, 16)
        out = self.intra_ln(out)
        skip = out+skip #(B,T,F,C)
        
        ## inter
        # input shape (B,T,F,C) --> (B*F,T,C)
        out = skip.transpose(1, 2).contiguous()
        out = out.view(4, 1, 16)
        # (B*F,T,C)
        out, state = self.inter_rnn(out, state)
        out = self.inter_fc(out)
        # (B*F,T,C) --> (B,T,F,C)
        out = out.view(1, 4, 1, 16)
        out = out.transpose(1, 2).contiguous()
        out = self.inter_ln(out)
        skip = out+skip #(B,T,F,C)

        # output shape (B,T,F,C) --> (B,C,T,F)
        skip = skip.permute(0, 3, 1, 2).contiguous()
        return skip, state

def test_inference():
    torch.manual_seed(0)
    in0 = torch.rand(1, 16, 1, 4, dtype=torch.float)
    in1 = torch.rand(2, 4, 16, dtype=torch.float)
    out = []

    with ncnn.Net() as net:
        net.load_param("dprnn.ncnn.param")
        net.load_model("dprnn.ncnn.bin")

        with net.create_extractor() as ex:
            # squeeze the batch dim before feeding ncnn (note: squeeze(1) on
            # the (2, 4, 16) state is a no-op, since that dim has size 4)
            ex.input("in0", ncnn.Mat(in0.squeeze(0).numpy()).clone())
            ex.input("in1", ncnn.Mat(in1.squeeze(1).numpy()).clone())

            # convert back to torch tensors, re-inserting the squeezed dims
            _, out0 = ex.extract("out0")
            out.append(torch.from_numpy(np.array(out0)).unsqueeze(0))
            _, out1 = ex.extract("out1")
            out.append(torch.from_numpy(np.array(out1)).unsqueeze(1))

    if len(out) == 1:
        return out[0]
    else:
        return tuple(out)
    
def test_inference_torch():
    torch.manual_seed(0)
    in0 = torch.rand(1, 16, 1, 4, dtype=torch.float)
    in1 = torch.rand(2, 4, 16, dtype=torch.float)
    
    net = DPRNN(16, 4, 16).eval()
    with torch.no_grad():
        out0, out1 = net(in0, in1)
    
    return tuple([out0, out1])

if __name__ == "__main__":
    # print(test_inference()[0].shape)
    print((test_inference()[0].permute(0, 3, 1, 2)-test_inference_torch()[0]).abs().max().item())
  3. Result: 5.277203559875488
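For reference, a small diagnostic can be appended to the test script above (a sketch; it reuses the two test functions) to print both output shapes before comparing values. The permute(0, 3, 1, 2) in the comparison assumes the ncnn blob comes back in (B, T, F, C) order rather than the expected (B, C, T, F):

# print both shapes first; the value comparison above only makes sense
# after permuting the ncnn output back to the torch layout
ncnn_out = test_inference()
torch_out = test_inference_torch()
print("ncnn :", tuple(ncnn_out[0].shape))
print("torch:", tuple(torch_out[0].shape))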

SherryYu33 · Mar 08 '24 08:03