NNlib.jl
Complex{Float32} is not a subtype of Float32, breaking conv with complex-valued inputs/weights
A simple convolution of two complex arrays fails:
using NNlib
conv(randn(Complex{Float32}, 25,25,1,3), randn(Complex{Float32}, 5,5,1,3))
The error traces back to this line
https://github.com/FluxML/NNlib.jl/blob/0d16973bab1260de045c1599ec3d12e5adac1d70/src/nnpack/interface.jl#L15
which tries to coerce a complex value into a real one by converting it to Float32. Changing both
https://github.com/FluxML/NNlib.jl/blob/0d16973bab1260de045c1599ec3d12e5adac1d70/src/nnpack/impl.jl#L8
and
https://github.com/FluxML/NNlib.jl/blob/0d16973bab1260de045c1599ec3d12e5adac1d70/src/nnpack/libnnpack.jl#L127
to also accept Complex{Float32} made it work as expected. I suspect a similar issue affects other functions that could work on complex inputs, but I haven't tested this extensively.
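A minimal illustration of the type mismatch in plain Julia. It does not reproduce the code on the linked NNPACK line; it only assumes that line performs an equivalent Float32 conversion. ComplexF32 (an alias for Complex{Float32}) is not a subtype of Float32, and converting a complex value with a nonzero imaginary part to Float32 throws an InexactError.

```julia
# Complex{Float32} (alias ComplexF32) is neither a Float32 nor a Real.
println(ComplexF32 <: Float32)   # false
println(ComplexF32 <: Real)      # false

z = 1.0f0 + 2.0f0im

# Converting to Float32 only succeeds when the imaginary part is exactly zero;
# otherwise Julia throws an InexactError.
try
    convert(Float32, z)
catch err
    println(typeof(err))         # InexactError
end

# A complex value with zero imaginary part converts without error.
println(convert(Float32, 3.0f0 + 0.0f0im))  # 3.0
```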
We are likely removing the forced conversion in #212. That said, I don't think NNPACK supports complex numbers, right? We could support complex convolution in conv_im2col, I guess.
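Because convolution is bilinear, a complex convolution can be assembled from real-valued ones, using the same identity as complex multiplication, (a + ib)(c + id) = (ac - bd) + i(ad + bc), with each product replaced by a convolution. The sketch below uses a hypothetical helper, complex_conv_4 (not part of NNlib), and assumes the non-conjugating convention, i.e. the weights are not conjugated. It relies only on NNlib's conv for real Float32 arrays and forwards any keyword arguments (stride, pad, and so on) unchanged.

```julia
using NNlib

# Hypothetical helper (not part of NNlib): complex convolution from four real
# convolutions, assuming no conjugation of the weights:
#   (xr + i*xi) ⋆ (wr + i*wi) = (xr⋆wr - xi⋆wi) + i*(xr⋆wi + xi⋆wr)
function complex_conv_4(x::AbstractArray{<:Complex}, w::AbstractArray{<:Complex}; kw...)
    xr, xi = real.(x), imag.(x)
    wr, wi = real.(w), imag.(w)
    yr = conv(xr, wr; kw...) .- conv(xi, wi; kw...)   # real part
    yi = conv(xr, wi; kw...) .+ conv(xi, wr; kw...)   # imaginary part
    return complex.(yr, yi)
end

x = randn(ComplexF32, 25, 25, 1, 3)
w = randn(ComplexF32, 5, 5, 1, 3)
y = complex_conv_4(x, w)
println(size(y))  # (21, 21, 3, 3)
```

The output has the usual conv shape: a 25×25 input with a 5×5 kernel and no padding gives a 21×21 spatial size, with 3 output channels (last dimension of w) and a batch of 3.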
OP's code now works on the CPU but fails when I try to run it on the GPU. I'd like this functionality, but as a relatively new programmer I'm not sure where to begin implementing this enhancement.
Flux's GPU support for convs currently relies on cuDNN, which to my knowledge doesn't support complex numbers. The main paths forward would be to find a way to make it support them, or to write GPU-friendly conv kernels we can use as a substitute.
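One possible route on the GPU, sketched under assumptions: since cuDNN handles real-valued convolutions, a complex conv could be decomposed into real ones that are dispatched as usual. The hypothetical helper complex_conv_3 below (not part of NNlib or Flux) uses Gauss's multiplication trick to need three real convolutions instead of four, at the cost of a few extra elementwise additions, and assumes the weights are not conjugated. It is written against plain NNlib.conv and shown on CPU arrays only; that passing CuArrays would route each real conv through cuDNN is an assumption, as is any speedup over a four-conv decomposition.

```julia
using NNlib

# Hypothetical helper (not part of NNlib or Flux): complex convolution from three
# real convolutions (Gauss's trick), assuming no conjugation of the weights:
#   k1 = (xr + xi) ⋆ wr
#   k2 = xr ⋆ (wi - wr)
#   k3 = xi ⋆ (wr + wi)
#   real part = k1 - k3 = xr⋆wr - xi⋆wi
#   imag part = k1 + k2 = xr⋆wi + xi⋆wr
function complex_conv_3(x::AbstractArray{<:Complex}, w::AbstractArray{<:Complex}; kw...)
    xr, xi = real.(x), imag.(x)
    wr, wi = real.(w), imag.(w)
    k1 = conv(xr .+ xi, wr; kw...)
    k2 = conv(xr, wi .- wr; kw...)
    k3 = conv(xi, wr .+ wi; kw...)
    return complex.(k1 .- k3, k1 .+ k2)
end

x = randn(ComplexF32, 25, 25, 1, 3)
w = randn(ComplexF32, 5, 5, 1, 3)
y = complex_conv_3(x, w)   # size (21, 21, 3, 3)
```

Whether three convs beat four in practice depends on kernel size and backend, so this would need benchmarking before being adopted.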