KernelFunctions.jl
                            
                            
                            
                        Differentiating `FunctionTransform` with Zygote
This may be another Zygote-related issue: differentiating through `FunctionTransform` with Zygote fails when the inputs are given as a matrix (multidimensional inputs):
julia> x = rand(2, 100);
julia> k(θ) = TransformedKernel(TransformedKernel(ExponentialKernel(), FunctionTransform(x->θ[1]*x)), SelectTransform([1]));
julia> o(θ) = sum(kernelmatrix(k(θ), x));
julia> o([2.3])
5636.35135777359
julia> Zygote.gradient(o, [2.3])
ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
 [1] check_broadcast_shape(::Tuple{}, ::Tuple{Base.OneTo{Int64}}) at ./broadcast.jl:518
 [2] check_broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}) at ./broadcast.jl:521
 [3] check_broadcast_axes at ./broadcast.jl:523 [inlined]
 [4] check_broadcast_axes at ./broadcast.jl:527 [inlined]
 [5] instantiate at ./broadcast.jl:269 [inlined]
 [6] materialize! at ./broadcast.jl:848 [inlined]
 [7] materialize! at ./broadcast.jl:845 [inlined]
 [8] (::Zygote.var"#347#349"{SubArray{Float64,2,Array{Float64,2},Tuple{Array{Int64,1},Base.Slice{Base.OneTo{Int64}}},false},Tuple{Colon,Int64}})(::Array{Float64,2}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/lib/array.jl:42
 [9] (::Zygote.var"#2199#back#345"{Zygote.var"#347#349"{SubArray{Float64,2,Array{Float64,2},Tuple{Array{Int64,1},Base.Slice{Base.OneTo{Int64}}},false},Tuple{Colon,Int64}}})(::Array{Float64,2}) at /Users/molet/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [10] #10 at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/transform/functiontransform.jl:25 [inlined]
 [11] (::typeof(∂(λ)))(::Array{Float64,2}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#493#497")(::typeof(∂(λ)), ::Array{Float64,2}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/lib/array.jl:187
 [13] (::Base.var"#3#4"{Zygote.var"#493#497"})(::Tuple{typeof(∂(λ)),Array{Float64,2}}) at ./generator.jl:36
 [14] iterate at ./generator.jl:47 [inlined]
 [15] collect at ./array.jl:686 [inlined]
 [16] map at ./abstractarray.jl:2248 [inlined]
 [17] (::Zygote.var"#492#496"{Array{typeof(∂(λ)),1}})(::Array{Array{Float64,2},1}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/lib/array.jl:187
 [18] (::Zygote.var"#2515#back#498"{Zygote.var"#492#496"{Array{typeof(∂(λ)),1}}})(::Array{Array{Float64,2},1}) at /Users/molet/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [19] _map at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/transform/functiontransform.jl:24 [inlined]
 [20] (::typeof(∂(_map)))(::NamedTuple{(:X,),Tuple{Array{Float64,2}}}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [21] kernelmatrix at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/kernels/transformedkernel.jl:89 [inlined]
 [22] (::typeof(∂(kernelmatrix)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [23] kernelmatrix at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/kernels/transformedkernel.jl:89 [inlined]
 [24] #kernelmatrix#93 at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/matrix/kernelmatrix.jl:117 [inlined]
 [25] (::typeof(∂(#kernelmatrix#93)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [26] kernelmatrix at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/matrix/kernelmatrix.jl:117 [inlined]
 [27] (::typeof(∂(kernelmatrix)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [28] o at ./REPL[34]:1 [inlined]
 [29] (::typeof(∂(o)))(::Float64) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [30] (::Zygote.var"#41#42"{typeof(∂(o))})(::Float64) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface.jl:45
 [31] gradient(::Function, ::Array{Float64,1}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface.jl:54
 [32] top-level scope at REPL[39]:1
I get a similar error message for the following as well, even though each input here is one-dimensional (a 1×100 matrix):
julia> x = rand(1, 100);
julia> k(θ) = TransformedKernel(ExponentialKernel(), FunctionTransform(x->θ[1]*x));
julia> o(θ) = sum(kernelmatrix(k(θ), x));
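However, with the same data passed as a plain vector, differentiation works. This suggests the failure is triggered by matrix-shaped input, not by the input dimension itself: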
julia> x = rand(100);
julia> k(θ) = TransformedKernel(ExponentialKernel(), FunctionTransform(x->θ[1]*x));
julia> o(θ) = sum(kernelmatrix(k(θ), x));
julia> o([2.3])
5588.952135766347
julia> Zygote.gradient(o, [2.3])
([-1182.066704150817],)
@molet and I discussed this.
This is definitely a bug, but for the time being #173 is sufficient for his needs, so we're going to push that forward.
I'll leave this issue open so that it can be properly resolved at some point.