
StreamConvTranspose2d error

Open YIN-jd opened this issue 8 months ago • 2 comments

Hello, I have a question. With the following StreamConvTranspose2d settings, I get an error. This is the test code:

```python
equivalent_conv = StreamConvTranspose2d(
    in_channels=in_channels,
    out_channels=hidden_channels,
    # kernel_size = 1,
    kernel_size=(1, 5),
    stride=(1, 2),
    padding=(0, 2),
    groups=2
)
original_layer = nn.ConvTranspose2d(
    in_channels=in_channels,
    out_channels=hidden_channels,
    # kernel_size = 1,
    kernel_size=(1, 5),
    stride=(1, 2),
    padding=(0, 2),
    groups=2
)

x = torch.randn(1, in_channels, 16, 128)
convert_to_stream(equivalent_conv, original_layer)

y_original = original_layer(x)
y_equivalent = equivalent_conv(x)
print(np.shape(y_original))
print(np.shape(y_equivalent))
assert torch.allclose(y_original, y_equivalent, atol=1e-6)
```

I found that with groups=2 it errors, but with groups=1 it does not.

YIN-jd avatar Aug 21 '25 08:08 YIN-jd

These codes were written quite a long time ago, so I have forgotten the details. However, when I tried running your code, I encountered the following error: `TypeError: StreamConvTranspose2d.forward() missing 1 required positional argument: 'cache'`.

This suggests that in `y_equivalent = equivalent_conv(x)`, you need to provide an additional argument `cache`. Did you not encounter the same error?

Xiaobin-Rong avatar Sep 01 '25 11:09 Xiaobin-Rong

> These codes were written quite a long time ago, so I have forgotten the details. However, when I tried running your code, I encountered the following error: `TypeError: StreamConvTranspose2d.forward() missing 1 required positional argument: 'cache'`.
>
> This suggests that in `y_equivalent = equivalent_conv(x)`, you need to provide an additional argument `cache`. Did you not encounter the same error?

Sorry, I had modified the `StreamConvTranspose2d` code, so `cache` is optional in my version.

This is my revised version:

```python
import torch
import torch.nn as nn
from typing import Tuple, Union


class StreamConvTranspose2d(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[str, int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 *args, **kargs):
        """
        kernel_size = [T_size, F_size] by default
        stride = [T_stride, F_stride], and T_stride == 1
        """
        super().__init__(*args, **kargs)
        self.in_channels = in_channels
        self.out_channels = out_channels

        if type(kernel_size) is int:
            self.T_size = kernel_size
            self.F_size = kernel_size
        elif type(kernel_size) in [list, tuple]:
            self.T_size, self.F_size = kernel_size
        else:
            raise ValueError('Invalid kernel size.')

        if type(stride) is int:
            self.T_stride = stride
            self.F_stride = stride
        elif type(stride) in [list, tuple]:
            self.T_stride, self.F_stride = stride
        else:
            raise ValueError('Invalid stride size.')

        assert self.T_stride == 1

        if type(padding) is int:
            self.T_pad = padding
            self.F_pad = padding
        elif type(padding) in [list, tuple]:
            self.T_pad, self.F_pad = padding
        else:
            raise ValueError('Invalid padding size.')
        assert self.T_pad == 0

        if type(dilation) is int:
            self.T_dilation = dilation
            self.F_dilation = dilation
        elif type(dilation) in [list, tuple]:
            self.T_dilation, self.F_dilation = dilation
        else:
            raise ValueError('Invalid dilation size.')

        # Implementing ConvTranspose2d using Conv2d with weight-time reversal.
        self.ConvTranspose2d = nn.Conv2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=kernel_size,
                                         stride=(self.T_stride, 1),  # additional upsampling is applied in forward if F_stride != 1
                                         padding=(self.T_pad, 0),    # additional padding is applied in forward if F_pad != 0
                                         dilation=dilation,
                                         groups=groups,
                                         bias=bias)

    def forward(self, x, cache=None):
        """
        x: [bs,C,1,F]
        cache: [bs,C,T-1,F] (optional)
        """
        if cache is None:
            inp = x
            return_cache = False
        else:
            inp = torch.cat([cache, x], dim=2)  # [bs,C,T,F]
            out_cache = inp[:, :, 1:]
            return_cache = True

        bs, C, T, F = inp.shape

        # Upsampling operation
        if self.F_stride > 1:
            # [bs,C,T,F] -> [bs,C,T,F,1] -> [bs,C,T,F,F_stride] -> [bs,C,T,F_out]
            inp = torch.cat([inp[:, :, :, :, None],
                             torch.zeros([bs, C, T, F, self.F_stride - 1], device=x.device, dtype=x.dtype)],
                            dim=-1).reshape([bs, C, T, -1])
            left_pad = self.F_stride - 1
            if self.F_size > 1:
                if left_pad <= self.F_size - 1:
                    inp = torch.nn.functional.pad(
                        inp, pad=[(self.F_size - 1) * self.F_dilation - self.F_pad,
                                  (self.F_size - 1) * self.F_dilation - self.F_pad - left_pad,
                                  0, 0])
                else:
                    # inp = torch.nn.functional.pad(inp, pad=[self.F_size - 1, 0, 0, 0])[:, :, :, :-(left_pad - self.F_stride + 1)]
                    raise NotImplementedError
            else:
                # inp = inp[:, :, :, :-left_pad]
                raise NotImplementedError
        else:  # F_stride = 1
            inp = torch.nn.functional.pad(
                inp, pad=[(self.F_size - 1) * self.F_dilation - self.F_pad,
                          (self.F_size - 1) * self.F_dilation - self.F_pad])

        outp = self.ConvTranspose2d(inp)
        if return_cache:
            return outp, out_cache
        else:
            return outp
```
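The zero-insertion upsampling that `forward` performs along F can be sanity-checked on its own: for `groups=1`, a `nn.ConvTranspose2d` with stride `s` and padding `p` equals zero-inserting `s - 1` samples after each input sample, padding asymmetrically, and running an ordinary convolution with the channel-swapped, frequency-flipped weight. A minimal sketch of this equivalence (the parameter values here are chosen only for illustration, not taken from the thread):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
cin, cout, k, s, p = 3, 2, 5, 2, 2  # illustrative channel/kernel/stride/padding choices

deconv = nn.ConvTranspose2d(cin, cout, (1, k), stride=(1, s), padding=(0, p))

x = torch.randn(1, cin, 4, 16)
bs, C, T, Fdim = x.shape

# Zero-insertion upsampling along F: append s-1 zeros after every sample.
up = torch.cat([x[..., None], torch.zeros(bs, C, T, Fdim, s - 1)], dim=-1).reshape(bs, C, T, -1)

# Asymmetric padding: the trailing zeros of the last sample already cover s-1 of the right pad.
left = k - 1 - p
right = k - 1 - p - (s - 1)
up = F.pad(up, [left, right, 0, 0])

# Ordinary convolution with the channel-swapped, frequency-flipped ConvTranspose2d weight.
w = deconv.weight.detach().transpose(0, 1).flip(2, 3)
y = F.conv2d(up, w, bias=deconv.bias)

assert torch.allclose(deconv(x), y, atol=1e-6)
```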

The test code is the same as before:

```python
in_channels = 32
hidden_channels = 16

original_layer = nn.ConvTranspose2d(
    in_channels=in_channels,
    out_channels=hidden_channels,
    # kernel_size = 1,
    kernel_size=(1, 5),
    stride=(1, 2),
    padding=(0, 2),
    groups=1
)

equivalent_conv = StreamConvTranspose2d(
    in_channels=in_channels,
    out_channels=hidden_channels,
    # kernel_size = 1,
    kernel_size=(1, 5),
    stride=(1, 2),
    padding=(0, 2),
    groups=1
)

x = torch.randn(1, in_channels, 16, 128)
convert_to_stream(equivalent_conv, original_layer)

y_original = original_layer(x)
y_equivalent = equivalent_conv(x)
print(np.shape(y_original))
print(np.shape(y_equivalent))
assert torch.allclose(y_original, y_equivalent, atol=1e-6)
```

I found that with groups=2 it errors, but with groups=1 it does not.
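I can't be sure this is what `convert_to_stream` does internally, but one plausible cause of the groups=2 mismatch is the weight layout: `nn.ConvTranspose2d` stores its weight as `(in_channels, out_channels // groups, kH, kW)` while `nn.Conv2d` expects `(out_channels, in_channels // groups, kH, kW)`, so a plain `transpose(0, 1)` during weight copying is only valid for `groups == 1`. For `groups > 1` the in/out swap has to happen within each group. A sketch of the group-aware conversion (stride 1 so no upsampling is involved; all names and sizes here are illustrative):

```python
import torch
import torch.nn as nn

g, cin, cout, k, p = 2, 4, 6, 5, 2  # illustrative sizes
deconv = nn.ConvTranspose2d(cin, cout, (1, k), stride=1, padding=(0, p), groups=g)

# (cin, cout//g, 1, k) -> split groups -> swap in/out inside each group -> merge -> flip kernel
w = deconv.weight.detach()
w = w.reshape(g, cin // g, cout // g, 1, k).transpose(1, 2)
w = w.reshape(cout, cin // g, 1, k).flip(2, 3)

# A stride-1 ConvTranspose2d with padding p equals a Conv2d with padding k-1-p and the converted weight.
conv = nn.Conv2d(cin, cout, (1, k), padding=(0, k - 1 - p), groups=g)
with torch.no_grad():
    conv.weight.copy_(w)
    conv.bias.copy_(deconv.bias)

x = torch.randn(1, cin, 3, 16)
assert torch.allclose(deconv(x), conv(x), atol=1e-6)
```

If `convert_to_stream` only does the `transpose(0, 1)` + flip, that would explain exactly the behavior you observed: correct for groups=1, wrong for groups=2.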

YIN-jd avatar Sep 02 '25 02:09 YIN-jd