
[Feature Request] Allow specifying `Conv` padding by strings

Open ChenAo-Phys opened this issue 1 year ago • 9 comments

Hello! I would like to be able to specify the padding in Conv with strings like 'SAME' or 'VALID'. This is compatible with the JAX API, since it's already supported by jax.lax.conv_general_dilated.

Furthermore, I would also like a 'CIRCULAR' padding, which is not in jax.lax.conv_general_dilated but is implemented in Flax and PyTorch. In the 1D case, for instance, this padding behaves as x[L + i] = x[i], so that the chain effectively becomes a ring. This is particularly helpful in some physical systems.
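
For instance, a minimal sketch of the behaviour (using jnp.pad with mode="wrap" rather than any particular Conv API):

    import jax.numpy as jnp

    x = jnp.arange(4.0)                      # a length-4 "chain"
    print(jnp.pad(x, (1, 1), mode="wrap"))   # [3. 0. 1. 2. 3. 0.]: the chain closes into a ring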

I can propose a PR if you think these changes are suitable. Thanks!

ChenAo-Phys avatar Jan 11 '24 15:01 ChenAo-Phys

I think that sounds reasonable to me. :) I'd be happy to take a pull request adding these options.

Do you know how the circular option might be implemented efficiently? Done naively I suspect we might end up copying the input.

patrick-kidger avatar Jan 21 '24 13:01 patrick-kidger

I think that sounds reasonable to me. :) I'd be happy to take a pull request adding these options.

Do you know how the circular option might be implemented efficiently? Done naively I suspect we might end up copying the input.

As far as I know, there is no "clever" way to do circular padding other than copying the input, probably because circular padding is not natively supported by cuDNN. That's also how other packages implement it.
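
As a rough sketch of that copy-then-convolve approach (illustrative only; the helper name circular_conv1d is made up, and a 1D input laid out as (channels, length) is assumed): wrap-pad the spatial dimension, then run an ordinary "VALID" convolution.

    import jax.numpy as jnp
    from jax import lax

    def circular_conv1d(x, w):
        # x: (in_channels, length); w: (out_channels, in_channels, kernel_size).
        # Stride-1 circular convolution: the output length equals the input length.
        kernel_size = w.shape[-1]
        pad_left = (kernel_size - 1) // 2
        pad_right = kernel_size - 1 - pad_left
        # The "copy the input" step: wrap the spatial edges around.
        x = jnp.pad(x, ((0, 0), (pad_left, pad_right)), mode="wrap")
        out = lax.conv_general_dilated(
            x[None],  # add a batch dimension: (1, in_channels, padded_length)
            w,
            window_strides=(1,),
            padding="VALID",
        )
        return out[0]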

As an aside, would you like to also accept string padding in ConvTranspose and Pool? This will make the API more self-consistent.

ChenAo-Phys avatar Jan 22 '24 13:01 ChenAo-Phys

As far as I know, there is no "clever" way to do circular padding other than copying the input, probably because circular padding is not natively supported by cuDNN. That's also how other packages implement it.

Ah well. Makes sense.

As an aside, would you like to also accept string padding in ConvTranspose and Pool? This will make the API more self-consistent.

Absolutely! That'd be great for consistency.

patrick-kidger avatar Jan 23 '24 00:01 patrick-kidger

I have been testing the new code these days and it works quite well. There are a few details I want to clarify before submitting the PR.

  • Conv "CIRCULAR" is actually a "padding mode" specifying how to fill padding values instead of how many additional units should be appended to the edge. In pytorch they have two different arguments, padding and padding mode, but in our case this will complicate the API. So I define "CIRCULAR" to be the same amount of padding as "SAME" filled with circular values instead of zeros, which is a suitable choice in most physical scenarios.

  • ConvTranspose: I want to make sure that when the padding is given by a string, ConvTranspose still produces a shape and connectivity compatible with the Conv that has the same padding. The key problem is how to define output_padding. For a constant amount of padding, Relation 15 in this paper gives o = floor((i + 2*p - d*(k - 1) - 1) / s) + 1, and output_padding can be defined as the quantity a that eliminates the floor via o = (i - a + 2*p - d*(k - 1) - 1) / s + 1. "SAME" always adds just enough padding that the floor never takes effect, but then the amount of padding p is no longer fixed. Under "SAME" padding the relation becomes o = ceil(i / s), so output_padding can instead be defined by o = (i + a) / s. In ConvTranspose one can then obtain i from o, and p follows from o = (i + 2*p - d*(k - 1) - 1) / s + 1. (A small numerical sketch of this arithmetic follows this list.)

  • Pool: The argument use_ceil now looks unnecessary, because "SAME" means "ceil" is used to compute the final output shape. The only difference is that "SAME" pads at both the head and the tail, while "ceil" only adds extra padding at the tail. I'll leave use_ceil in place for backward compatibility of the API, but we could consider removing it later.
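
As promised above, a small numerical sketch of the ConvTranspose shape arithmetic (the helper name same_transpose_shapes is made up for illustration and is not proposed API): given the output size o of a "SAME"-padded Conv, it recovers a compatible input size i and the total padding, using the relations quoted above.

    import math

    def same_transpose_shapes(o, kernel_size, stride, dilation=1, output_padding=0):
        # Illustration only: follows o = (i + p_total - d*(k - 1) - 1) / s + 1
        # (the floor-free form) together with i + output_padding = o * stride.
        assert 0 <= output_padding < stride
        i = o * stride - output_padding
        p_total = dilation * (kernel_size - 1) + 1 - stride + output_padding
        # Sanity checks: the division is exact (no floor needed) and reproduces o,
        # and i really is a "SAME"-compatible input size for this o.
        numerator = i + p_total - dilation * (kernel_size - 1) - 1
        assert numerator % stride == 0 and numerator // stride + 1 == o
        assert math.ceil(i / stride) == o
        return i, p_total

    print(same_transpose_shapes(o=2, kernel_size=2, stride=2))                     # (4, 0)
    print(same_transpose_shapes(o=2, kernel_size=2, stride=2, output_padding=1))   # (3, 1)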

ChenAo-Phys avatar Jan 24 '24 15:01 ChenAo-Phys

  • Conv: I think I like PyTorch's approach here. We could have padding_mode="zeros"/"circular". Basically, this seems slightly more general and I don't think we lose anything (including readability) by going this way.
  • ConvTranspose: I don't think there should be any problem here. Internally I'm imagining the code looks something like this:
    def _padding_string_to_int(x: str):
        if x.lower() == "same":
            return ...
        elif x.lower() == "valid":
            return 0
        else:
            raise ValueError(...)
    
    class Conv:
        def __init__(self, padding, ...):
            if isinstance(padding, str):
                padding = _padding_string_to_int(padding)
            ...
    
    class ConvTranspose:
        def __init__(self, padding, ...):
            if isinstance(padding, str):
                padding = _padding_string_to_int(padding)
            ...
    
    so that supporting these extra modes means very little extra work.
  • Pool: you say that "SAME" means "ceil" is used to compute the final output shape, but I don't think that's actually the case: "SAME" means that the input shape determines the output shape directly, so use_ceil is completely superfluous there. If padding="SAME" then I'd just suggest ignoring the value of use_ceil entirely. I don't think I'd deprecate it, though; it still does something useful for more general padding scenarios.

patrick-kidger avatar Jan 26 '24 20:01 patrick-kidger

  • Conv: I think I like PyTorch's approach here. We could have padding_mode="zeros"/"circular". Basically, this seems slightly more general and I don't think we lose anything (including readability) by going this way.

OK, that's nice. Initially I just wanted to make minimal changes, but I'm glad to adopt the PyTorch approach now that you've agreed to it.

  • ConvTranspose: I don't think there should be any problem here. Internally I'm imagining the code looks something like this:
    def _padding_string_to_int(x: str):
        if x.lower() == "same":
            return ...
        elif x.lower() == "valid":
            return 0
        else:
            raise ValueError(...)
    
    class Conv:
        def __init__(self, padding, ...):
            if isinstance(padding, str):
                padding = _padding_string_to_int(padding)
            ...
    
    class ConvTranspose:
        def __init__(self, padding, ...):
            if isinstance(padding, str):
                padding = _padding_string_to_int(padding)
            ...
    
    so that supporting these extra modes means very little extra work.

This code doesn't work, because with "SAME" the exact amount of padding depends on the input shape, which we don't know during initialization. We can only store self.padding = "SAME" and then determine the exact padding in the forward pass.

This is an even more severe problem for ConvTranspose. For instance, if a 1D conv layer has padding="SAME", kernel_size=2, stride=2, and out_shape = 2, it can either be (x means padding, |--| means a kernel window)

    0 1 2 x      (in_shape = 3, padding = 1)
    |--| |--|

or

    0 1 2 3      (in_shape = 4, padding = 0)
    |--| |--|
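
A quick sketch checking that ambiguity directly with jax.lax.conv_general_dilated: with kernel_size=2 and stride=2, "SAME" maps both an input of length 3 and an input of length 4 to an output of length 2.

    import jax.numpy as jnp
    from jax import lax

    w = jnp.ones((1, 1, 2))  # (out_channels, in_channels, kernel_size)
    for length in (3, 4):
        x = jnp.ones((1, 1, length))  # (batch, channels, length)
        out = lax.conv_general_dilated(x, w, window_strides=(2,), padding="SAME")
        print(length, out.shape)  # both cases print an output shape of (1, 1, 2)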

So for ConvTranspose it becomes a non-trivial problem to determine the amount of padding. PyTorch simply gives up on supporting "SAME" padding in ConvTranspose, but this can be solved with the method I described above.

  • Pool: you say that "SAME" means "ceil" is used to compute the final output shape, but I don't think that's actually the case: "SAME" means that the input shape determines the output shape directly, so use_ceil is completely superfluous there. If padding="SAME" then I'd just suggest ignoring the value of use_ceil entirely. I don't think I'd deprecate it, though; it still does something useful for more general padding scenarios.

OK, I don't mind leaving use_ceil there.

ChenAo-Phys avatar Jan 26 '24 23:01 ChenAo-Phys

Ah, good point about SAME having to be a string that's resolved at call-time.

For ConvTranspose, I think I agree. We know exactly what output shape we want (the same as the input), so with a bit of algebra we should be able to figure out the value of output_padding that gives the desired shape.
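
For reference, a hedged sketch of the call-time "SAME" arithmetic (the helper name resolve_same_padding and the low/high splitting convention are illustrative, not Equinox's actual API): per spatial dimension, once the input size is known, the total padding follows from the target output size ceil(in_size / stride).

    import math

    def resolve_same_padding(in_size, kernel_size, stride, dilation=1):
        # Pick the total padding so that the output size is ceil(in_size / stride),
        # then split it into (low, high) amounts for that dimension.
        out_size = math.ceil(in_size / stride)
        effective_kernel = dilation * (kernel_size - 1) + 1
        total = max((out_size - 1) * stride + effective_kernel - in_size, 0)
        return total // 2, total - total // 2

    print(resolve_same_padding(3, kernel_size=2, stride=2))  # (0, 1)
    print(resolve_same_padding(4, kernel_size=2, stride=2))  # (0, 0)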

patrick-kidger avatar Jan 29 '24 21:01 patrick-kidger

Hi,

I would also find more padding options useful. The periodic/circular padding in particular would be of great use for my applications. Personally, I also prefer the additional keyword argument padding_mode. This could also be a great opportunity to give equinox periodic/circular padding in transposed convolutions, which is not yet available in PyTorch but is relevant for UNets, and decoder architectures in general, when applied to periodic fields. How is the PR progressing, @ChenAo-Phys? I would be very happy to help 😊.

Ceyron avatar Feb 14 '24 16:02 Ceyron

Hi,

I would also find more padding options useful. The periodic/circular padding in particular would be of great use for my applications. Personally, I also prefer the additional keyword argument padding_mode. This could also be a great opportunity to give equinox periodic/circular padding in transposed convolutions, which is not yet available in PyTorch but is relevant for UNets, and decoder architectures in general, when applied to periodic fields. How is the PR progressing, @ChenAo-Phys? I would be very happy to help 😊.

It's almost ready. Some other things came up two weeks ago, so the PR has been delayed. I will submit it soon, probably this week.

ChenAo-Phys avatar Feb 14 '24 17:02 ChenAo-Phys