KernelFunctions.jl

Differentiating `FunctionTransform` with Zygote

Open • molet opened this issue 5 years ago • 1 comment

I guess this might be another Zygote-related issue. Differentiating through a `FunctionTransform` doesn't work for multidimensional (matrix) inputs:

julia> x = rand(2, 100);
julia> k(θ) = TransformedKernel(TransformedKernel(ExponentialKernel(), FunctionTransform(x->θ[1]*x)), SelectTransform([1]));
julia> o(θ) = sum(kernelmatrix(k(θ), x));
julia> o([2.3])
5636.35135777359
julia> Zygote.gradient(o, [2.3])
ERROR: DimensionMismatch("cannot broadcast array to have fewer dimensions")
Stacktrace:
 [1] check_broadcast_shape(::Tuple{}, ::Tuple{Base.OneTo{Int64}}) at ./broadcast.jl:518
 [2] check_broadcast_shape(::Tuple{Base.OneTo{Int64}}, ::Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}) at ./broadcast.jl:521
 [3] check_broadcast_axes at ./broadcast.jl:523 [inlined]
 [4] check_broadcast_axes at ./broadcast.jl:527 [inlined]
 [5] instantiate at ./broadcast.jl:269 [inlined]
 [6] materialize! at ./broadcast.jl:848 [inlined]
 [7] materialize! at ./broadcast.jl:845 [inlined]
 [8] (::Zygote.var"#347#349"{SubArray{Float64,2,Array{Float64,2},Tuple{Array{Int64,1},Base.Slice{Base.OneTo{Int64}}},false},Tuple{Colon,Int64}})(::Array{Float64,2}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/lib/array.jl:42
 [9] (::Zygote.var"#2199#back#345"{Zygote.var"#347#349"{SubArray{Float64,2,Array{Float64,2},Tuple{Array{Int64,1},Base.Slice{Base.OneTo{Int64}}},false},Tuple{Colon,Int64}}})(::Array{Float64,2}) at /Users/molet/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [10] #10 at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/transform/functiontransform.jl:25 [inlined]
 [11] (::typeof(∂(λ)))(::Array{Float64,2}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#493#497")(::typeof(∂(λ)), ::Array{Float64,2}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/lib/array.jl:187
 [13] (::Base.var"#3#4"{Zygote.var"#493#497"})(::Tuple{typeof(∂(λ)),Array{Float64,2}}) at ./generator.jl:36
 [14] iterate at ./generator.jl:47 [inlined]
 [15] collect at ./array.jl:686 [inlined]
 [16] map at ./abstractarray.jl:2248 [inlined]
 [17] (::Zygote.var"#492#496"{Array{typeof(∂(λ)),1}})(::Array{Array{Float64,2},1}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/lib/array.jl:187
 [18] (::Zygote.var"#2515#back#498"{Zygote.var"#492#496"{Array{typeof(∂(λ)),1}}})(::Array{Array{Float64,2},1}) at /Users/molet/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [19] _map at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/transform/functiontransform.jl:24 [inlined]
 [20] (::typeof(∂(_map)))(::NamedTuple{(:X,),Tuple{Array{Float64,2}}}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [21] kernelmatrix at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/kernels/transformedkernel.jl:89 [inlined]
 [22] (::typeof(∂(kernelmatrix)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [23] kernelmatrix at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/kernels/transformedkernel.jl:89 [inlined]
 [24] #kernelmatrix#93 at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/matrix/kernelmatrix.jl:117 [inlined]
 [25] (::typeof(∂(#kernelmatrix#93)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [26] kernelmatrix at /Users/molet/.julia/packages/KernelFunctions/6cGns/src/matrix/kernelmatrix.jl:117 [inlined]
 [27] (::typeof(∂(kernelmatrix)))(::FillArrays.Fill{Float64,2,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [28] o at ./REPL[34]:1 [inlined]
 [29] (::typeof(∂(o)))(::Float64) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface2.jl:0
 [30] (::Zygote.var"#41#42"{typeof(∂(o))})(::Float64) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface.jl:45
 [31] gradient(::Function, ::Array{Float64,1}) at /Users/molet/.julia/packages/Zygote/NSWXO/src/compiler/interface.jl:54
 [32] top-level scope at REPL[39]:1

I get a similar error message for the following as well:

julia> x = rand(1, 100);
julia> k(θ) = TransformedKernel(ExponentialKernel(), FunctionTransform(x->θ[1]*x));
julia> o(θ) = sum(kernelmatrix(k(θ), x));

However, it seems to work for vector input:

julia> x = rand(100);
julia> k(θ) = TransformedKernel(ExponentialKernel(), FunctionTransform(x->θ[1]*x));
julia> o(θ) = sum(kernelmatrix(k(θ), x));
julia> o([2.3])
5588.952135766347
julia> Zygote.gradient(o, [2.3])
([-1182.066704150817],)

molet, Sep 22 '20 09:09

@molet and I had a chat about this.

This is definitely an issue, but for the time being #173 is sufficient for his needs, so we're going to push that forwards.

I'll leave this issue open as someone should probably try to resolve it at some point.

willtebbutt, Sep 22 '20 18:09
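
For the specific reproducer above, where the function passed to `FunctionTransform` only multiplies the input by a scalar, one possible way to sidestep the problem is to use `ScaleTransform` instead. The following is an untested sketch under that assumption, not the fix discussed in the thread (the contents of #173 are not shown here). It also assumes a KernelFunctions.jl version that exports `ColVecs`, and it leaves out the `SelectTransform` from the original example. Whether it avoids the `DimensionMismatch` on a given package version is not confirmed by anything in this issue.

```julia
using KernelFunctions
using Zygote

# Matrix input with one observation per column, as in the report above.
x = rand(2, 100)

# Assumption: the transform is a plain scaling, so ScaleTransform stands in
# for FunctionTransform(x -> θ[1] * x). The SelectTransform is omitted here.
k(θ) = TransformedKernel(ExponentialKernel(), ScaleTransform(θ[1]))

# ColVecs tells KernelFunctions to treat each column of x as one observation.
o(θ) = sum(kernelmatrix(k(θ), ColVecs(x)))

# Gradient of the summed kernel matrix with respect to the scale parameter.
g, = Zygote.gradient(o, [2.3])
```

This only covers scaling. An arbitrary function wrapped in `FunctionTransform` still goes through the code path shown in the stack trace, so the underlying issue remains open.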