
Complex{Float32} is not a subtype of Float32, breaking conv of complex-valued input/weights

Open · dsweber2 opened this issue on Jun 12, 2020 · 3 comments

A simple convolution of two complex-valued arrays breaks: `conv(randn(Complex{Float32}, 25, 25, 1, 3), randn(Complex{Float32}, 5, 5, 1, 3))`

The error comes from this line https://github.com/FluxML/NNlib.jl/blob/0d16973bab1260de045c1599ec3d12e5adac1d70/src/nnpack/interface.jl#L15, which tries to coerce the complex input into a real one by converting it to Float32. Tweaking both https://github.com/FluxML/NNlib.jl/blob/0d16973bab1260de045c1599ec3d12e5adac1d70/src/nnpack/impl.jl#L8 and https://github.com/FluxML/NNlib.jl/blob/0d16973bab1260de045c1599ec3d12e5adac1d70/src/nnpack/libnnpack.jl#L127 to also accept Complex{Float32} made the call work as expected. I suspect a similar issue affects any other function that could operate on complex inputs, but I haven't tested extensively.
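
For anyone hitting the same thing, here is a minimal standalone illustration of the type mismatch (hypothetical code, not the actual NNlib source):

```julia
# Hypothetical standalone illustration, not the actual NNlib source.
# In Julia, Complex{Float32} <: Float32 is false, so a Float32-only
# method signature rejects complex arrays outright:
f(x::AbstractArray{Float32}) = sum(x)

f(randn(Float32, 3))             # works
# f(randn(Complex{Float32}, 3))  # MethodError: no matching method

# Widening the element-type restriction, the kind of tweak described
# above, accepts both real and complex input:
g(x::AbstractArray{<:Union{Float32,Complex{Float32}}}) = sum(x)
g(randn(Complex{Float32}, 3))    # works
```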

dsweber2 · Jun 12, 2020

We are likely removing the forced conversion in #212. That said, I don't think NNPACK has complex support, right? We could support complex convolution in conv_im2col, I guess.
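
For reference, a rough sketch of what that could look like (a hypothetical helper, not NNlib API): since convolution is bilinear, a complex conv decomposes into four real convs via conv(a + i b, c + i d) = (conv(a, c) - conv(b, d)) + i (conv(a, d) + conv(b, c)).

```julia
using NNlib

# Sketch: complex convolution assembled from four real convolutions.
# conv_complex is a hypothetical helper, not part of NNlib.
function conv_complex(x::AbstractArray{<:Complex}, w::AbstractArray{<:Complex}; kw...)
    xr, xi = real.(x), imag.(x)
    wr, wi = real.(w), imag.(w)
    yr = conv(xr, wr; kw...) .- conv(xi, wi; kw...)  # real part
    yi = conv(xr, wi; kw...) .+ conv(xi, wr; kw...)  # imaginary part
    return complex.(yr, yi)
end

conv_complex(randn(Complex{Float32}, 25, 25, 1, 3),
             randn(Complex{Float32}, 5, 5, 1, 3))
```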

CarloLucibello · Jun 17, 2020

The OP's code now works on the CPU but fails when I try it on the GPU. I'd like this functionality, but as a relatively new programmer I'm not sure where to begin implementing this enhancement.

aksuhton · Jun 14, 2023

Flux's GPU support for convs currently relies on cuDNN, which to my knowledge doesn't support complex numbers. If you can figure out a way to make it support them or to write GPU-friendly conv kernels we can use as a substitute, those would be the main paths for moving this forward.
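
As a stopgap, the real/imaginary split sketched in the earlier comment should also work on the GPU, since each of the four real-valued convolutions can dispatch to cuDNN as usual. An untested sketch, assuming the hypothetical conv_complex helper from above is defined:

```julia
using CUDA, NNlib  # assumes the conv_complex sketch from the earlier comment

x = CuArray(randn(Complex{Float32}, 25, 25, 1, 3))
w = CuArray(randn(Complex{Float32}, 5, 5, 1, 3))
y = conv_complex(x, w)  # each real-valued conv call goes through cuDNN
```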

ToucheSir · Jun 14, 2023