Update the conv_transpose2d usage?

Open • sdw95927 opened this issue 3 years ago • 3 comments

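I used these excellent scripts and found that conv_transpose2d does not work properly for my own models. The original call hardcodes padding=1 and uses the default stride of 1, so it does not match layers with other strides or padding. I updated it in functional.conv, line 27, as follows: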

    # relevance_input  = F.conv_transpose2d(relevance_output, weight, None, padding=1)
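    # output_padding re-adds the row/column that the forward conv's floor division drops
    # when stride >= 2 (a heuristic: the exact amount depends on input size, kernel and padding)
    if ctx.stride[0] >= 2: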
        output_padding = 1
    else:
        output_padding = 0
    relevance_input  = F.conv_transpose2d(relevance_output, weight, None, stride=ctx.stride, padding=ctx.padding, output_padding=output_padding)
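The stride >= 2 check above assumes that exactly one extra row/column is always needed. A minimal sketch of computing output_padding exactly from the tensor shapes instead, assuming NCHW tensors and tuple-valued stride/padding/dilation. The names transposed_output_padding and relevance_backward are hypothetical helpers, not part of TorchLRP:

```python
import torch
import torch.nn.functional as F


def transposed_output_padding(in_size, out_size, kernel_size, stride, padding, dilation):
    """Per spatial dim, the extra rows/cols conv_transpose2d must add so its
    output matches the forward conv's input size exactly (hypothetical helper)."""
    pads = []
    for n_in, n_out, k, s, p, d in zip(in_size, out_size, kernel_size, stride, padding, dilation):
        # Size that conv_transpose2d would produce with output_padding=0.
        base = (n_out - 1) * s - 2 * p + d * (k - 1) + 1
        # The difference is always in [0, s - 1] for a valid forward conv.
        pads.append(n_in - base)
    return tuple(pads)


def relevance_backward(x, weight, relevance_output, stride, padding, dilation=(1, 1), groups=1):
    """Map a relevance map from the conv output back onto the input grid."""
    op = transposed_output_padding(
        x.shape[-2:], relevance_output.shape[-2:], weight.shape[-2:],
        stride, padding, dilation,
    )
    return F.conv_transpose2d(
        relevance_output, weight, None, stride=stride, padding=padding,
        output_padding=op, groups=groups, dilation=dilation,
    )


# Example: odd height, even width, stride 2, 3x3 kernel, padding 1.
x = torch.randn(1, 3, 33, 32)
w = torch.randn(8, 3, 3, 3)
z = F.conv2d(x, w, None, stride=(2, 2), padding=(1, 1))  # spatial size 17 x 16
r_in = relevance_backward(x, w, torch.rand_like(z), stride=(2, 2), padding=(1, 1))
print(z.shape, r_in.shape)  # r_in is back to 33 x 32
```

Here the exact values are (0, 1). The heuristic would use 1 for both dimensions and produce a 34-row map. If you prefer not to do the arithmetic by hand, torch.nn.grad.conv2d_input, which takes the input size explicitly, should perform the same computation.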

And also in the helper f, which sums the relevance of two input/weight pairs:

        def f(X1, X2, W1, W2, ctx): 

            # Z1  = F.conv2d(X1, W1, bias=None, stride=1, padding=1) 
            # Z2  = F.conv2d(X2, W2, bias=None, stride=1, padding=1)
            Z1 = F.conv2d(X1, W1, None, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
            Z2 = F.conv2d(X2, W2, None, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
            Z   = Z1 + Z2

            rel_out = relevance_output / (Z + (Z==0).float()* 1e-6)

            # t1 = F.conv_transpose2d(rel_out, W1, bias=None, padding=1) 
            # t2 = F.conv_transpose2d(rel_out, W2, bias=None, padding=1)
            if ctx.stride[0] >= 2:
                output_padding = 1
            else:
                output_padding = 0
            t1 = F.conv_transpose2d(rel_out, W1, None, stride=ctx.stride, padding=ctx.padding, output_padding=output_padding)
            t2 = F.conv_transpose2d(rel_out, W2, None, stride=ctx.stride, padding=ctx.padding, output_padding=output_padding)

            r1  = t1 * X1
            r2  = t2 * X2

            return r1 + r2
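Why the transposed convolution has to reuse the forward layer's stride and padding: for a bias-free conv2d, conv_transpose2d(rel_out, W, ...) with matching hyperparameters is exactly the gradient of sum(conv2d(X, W) * rel_out) with respect to X. t1 and t2 in f are that quantity for each input/weight pair. A small check, assuming arbitrary shapes (a stride-2, 3x3 layer on a 16 x 16 input):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
stride, padding = (2, 2), (1, 1)
x = torch.randn(2, 4, 16, 16, requires_grad=True)
w = torch.randn(6, 4, 3, 3)

z = F.conv2d(x, w, None, stride=stride, padding=padding)  # shape (2, 6, 8, 8)
rel_out = torch.randn_like(z)

# Vector-Jacobian product of the forward conv w.r.t. its input, via autograd.
(vjp,) = torch.autograd.grad(z, x, grad_outputs=rel_out)

# Same quantity via conv_transpose2d with the forward layer's stride/padding.
# (8 - 1) * 2 - 2 + 3 = 15, so one extra row/column is needed to get back to 16.
t = F.conv_transpose2d(rel_out, w, None, stride=stride, padding=padding, output_padding=(1, 1))

# The hardcoded call from the original code (stride 1, padding 1) stays at 8 x 8.
t_old = F.conv_transpose2d(rel_out, w, None, padding=1)

# In the rule itself, relevance is then distributed as t * x (cf. r1 = t1 * X1).
r = t * x.detach()
print(torch.allclose(t, vjp, rtol=1e-4, atol=1e-4), t_old.shape)
```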

I'm not sure whether the problem is specific to my setup, but the change above fixed it.

sdw95927 • Apr 30 '21 06:04

Thank you @sdw95927 for posting this. After making the changes above, I get the error 'Conv2DAlpha1Beta0Backward' object has no attribute 'stride'. Where should I set the stride on ctx?

miladsikaroudi • Apr 25 '22 15:04
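For context on that error: PyTorch names the ctx object of a custom torch.autograd.Function class <ClassName>Backward. So "'Conv2DAlpha1Beta0Backward' object has no attribute 'stride'" means the Function's forward never attached a stride to ctx. The sketch below shows where such attributes are usually set. It assumes a plain z-rule (LRP-0) with the 1e-6 zero guard from the snippet above. Conv2DZRule and _pair are hypothetical names, and TorchLRP's actual classes may be structured differently:

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function


def _pair(v):
    # ctx.stride[0] is indexed in the backward code, so always store 2-tuples.
    return (v, v) if isinstance(v, int) else tuple(v)


class Conv2DZRule(Function):  # hypothetical class, not TorchLRP's own
    @staticmethod
    def forward(ctx, x, weight, bias, stride, padding, dilation, groups):
        # Everything backward() reads from ctx has to be attached here.
        ctx.stride, ctx.padding, ctx.dilation = _pair(stride), _pair(padding), _pair(dilation)
        ctx.groups = groups
        ctx.save_for_backward(x, weight)
        return F.conv2d(x, weight, bias, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)

    @staticmethod
    def backward(ctx, relevance_output):
        x, weight = ctx.saved_tensors
        # Bias is ignored here, so its share of the relevance is dropped.
        z = F.conv2d(x, weight, None, ctx.stride, ctx.padding, ctx.dilation, ctx.groups)
        rel_out = relevance_output / (z + (z == 0).float() * 1e-6)
        # Exact output_padding per spatial dim so the result matches x's size.
        output_padding = tuple(
            n_in - ((n_out - 1) * s - 2 * p + d * (k - 1) + 1)
            for n_in, n_out, k, s, p, d in zip(
                x.shape[-2:], z.shape[-2:], weight.shape[-2:],
                ctx.stride, ctx.padding, ctx.dilation,
            )
        )
        t = F.conv_transpose2d(
            rel_out, weight, None, stride=ctx.stride, padding=ctx.padding,
            output_padding=output_padding, groups=ctx.groups, dilation=ctx.dilation,
        )
        # One return value per forward() input; non-tensor arguments get None.
        return t * x, None, None, None, None, None, None


# Usage (all arguments passed positionally to apply).
torch.manual_seed(0)
x = torch.rand(1, 3, 33, 32, requires_grad=True)
w = torch.randn(8, 3, 3, 3)
z = Conv2DZRule.apply(x, w, None, 2, 1, 1, 1)
z.backward(z.detach())  # start from relevance = output activations
```

With the relevance initialized to the output activations and no bias, x.grad should sum to z.sum() up to float error. This is the conservation property the z-rule is meant to have.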

Can you post a screenshot of both your script and the error?

sdw95927 • Apr 25 '22 16:04

The problem is fixed. Thank you so much.

miladsikaroudi • Apr 25 '22 17:04