ReverseDiff.jl
Cannot compute derivatives of quadratic form.
I tried to compute the gradient of a function containing a quadratic form. However, it failed with an ambiguity error as follows:
import ReverseDiff
const A = [1.0 2.0; 2.0 5.0]
quadratic(x) = x' * A * x
ReverseDiff.gradient(quadratic, ones(2))
MethodError: *(::RowVector{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) is ambiguous. Candidates:
*(x::AbstractArray{T,2} where T, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
*(x::AbstractArray, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
*(rowvec::RowVector{T,V} where V<:(AbstractArray{T,1} where T), vec::AbstractArray{T,1}) where T<:Real in Base.LinAlg at linalg/rowvector.jl:170
Possible fix, define
*(::RowVector{ReverseDiff.TrackedReal{V,D,ReverseDiff.TrackedArray{V,D,1,VA,DA}},V} where V<:(AbstractArray{T,1} where T), ::ReverseDiff.TrackedArray{V,D,1,VA,DA})
Stacktrace:
[1] * at ./operators.jl:424 [inlined]
[2] quadratic(::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) at ./In[32]:2
[3] Type at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/tape.jl:199 [inlined]
[4] gradient(::Function, ::Array{Float64,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}) at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/gradients.jl:22 (repeats 2 times)
I thought matrix multiplication was already supported. I'm not sure whether this is an unsupported feature or a bug, so let me file an issue here.
I'm using ReverseDiff.jl v0.1.4 on Julia 0.6.
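For reference, the analytic gradient of a quadratic form f(x) = x' * A * x is (A + A') * x, which is what a working AD implementation should return. A minimal check in plain Julia (no AD involved):

```julia
# Analytic gradient of f(x) = x' * A * x is (A + A') * x,
# which reduces to 2 * A * x when A is symmetric.
A = [1.0 2.0; 2.0 5.0]
x = ones(2)
g = (A + A') * x    # [6.0, 14.0] for this symmetric A
```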
I don't know whether this tackles the problem as you'd have wished, but there are a few ad hoc workarounds. First, compute the vector * matrix product as the sum of an element-wise product:
julia> ReverseDiff.gradient(x->sum(x.*A,1) * x,x)
2-element Array{Float64,1}:
4.57215
10.7736
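My understanding of why this workaround sidesteps the ambiguity (not verified against the ReverseDiff internals): broadcasting x .* A and summing along the first dimension reproduces x' * A as a plain 1×2 Array, so the RowVector type from the ambiguous method never appears:

```julia
# sum(x .* A, 1) builds the row vector x' * A as an ordinary 1x2 Array.
# (On Julia >= 0.7 the keyword form is sum(x .* A, dims=1).)
A = [1.0 2.0; 2.0 5.0]
x = [1.0, 2.0]
row = sum(x .* A, dims=1)
# Both sides equal [5.0, 12.0]:
@assert vec(row) == vec(x' * A)
```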
Another possibility is to compute the products separately:
julia> ReverseDiff.gradient(x->Array(x.'*A) * x,x)
2-element Array{Float64,1}:
4.57215
10.7736
Without the Array(), the product returns a RowVector{Float64,Array{Float64,1}}, which makes ReverseDiff unhappy...
This is the classic "new array types which define ambiguous method definitions" problem. It might be fixed just by adding :RowVector to the ambiguity list.
@oViMo Thanks. I'm avoiding the problem with techniques similar to the ones you suggested.
@jrevels I've tried your suggestion, but it still has the ambiguity problem.
diff --git a/src/ReverseDiff.jl b/src/ReverseDiff.jl
index 6b0635c..e472e73 100644
--- a/src/ReverseDiff.jl
+++ b/src/ReverseDiff.jl
@@ -21,7 +21,7 @@ end
# Not all operations will be valid over all of these types, but that's okay; such cases
# will simply error when they hit the original operation in the overloaded definition.
-const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix)
+const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix, :RowVector)
const REAL_TYPES = (:Bool, :Integer, :Rational, :BigFloat, :BigInt, :AbstractFloat, :Real, :Dual)
const FORWARD_UNARY_SCALAR_FUNCS = (ForwardDiff.AUTO_DEFINED_UNARY_FUNCS..., :-, :abs, :conj)
julia> using ReverseDiff
INFO: Recompiling stale cache file /Users/kenta/.julia/lib/v0.6/ReverseDiff.ji for module ReverseDiff.
julia> const A = [1.0 2.0; 2.0 5.0]
2×2 Array{Float64,2}:
1.0 2.0
2.0 5.0
julia> quadratic(x) = x' * A * x
quadratic (generic function with 1 method)
julia> ReverseDiff.gradient(quadratic, ones(2))
ERROR: MethodError: *(::RowVector{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) is ambiguous. Candidates:
*(x::RowVector, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
*(x::AbstractArray{T,2} where T, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
*(x::AbstractArray, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
*(rowvec::RowVector{T,V} where V<:(AbstractArray{T,1} where T), vec::AbstractArray{T,1}) where T<:Real in Base.LinAlg at linalg/rowvector.jl:170
Possible fix, define
*(::RowVector{ReverseDiff.TrackedReal{V,D,ReverseDiff.TrackedArray{V,D,1,VA,DA}},V} where V<:(AbstractArray{T,1} where T), ::ReverseDiff.TrackedArray{V,D,1,VA,DA})
Stacktrace:
[1] * at ./operators.jl:424 [inlined]
[2] quadratic(::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) at ./REPL[3]:1
[3] Type at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/tape.jl:199 [inlined]
[4] gradient(::Function, ::Array{Float64,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}) at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/gradients.jl:22 (repeats 2 times)
This no longer gives an error. But if A is not symmetric, it gives wrong results, as seen in this Discourse thread:
julia> const A3 = [1.0 2.0; 7.0 5.0];
julia> quadratic3(x) = x' * A3 * x;
julia> ReverseDiff.gradient(quadratic3, ones(2)) # wrong
2-element Vector{Float64}:
16.0
14.0
julia> ForwardDiff.gradient(quadratic3, ones(2))
2-element Vector{Float64}:
11.0
19.0
julia> Zygote.gradient(quadratic3, ones(2))[1]
2-element Vector{Float64}:
11.0
19.0
julia> ReverseDiff.gradient(x -> dot(x, A3, x), ones(2)) # dot works
2-element Vector{Float64}:
11.0
19.0
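Besides the three-argument dot shown above, another way to avoid the RowVector/Adjoint code path (a sketch, assuming LinearAlgebra's two-argument dot; whether ReverseDiff differentiates it correctly should be verified on the version in use) is to write the quadratic form as dot(x, A3 * x), and check the result against the analytic gradient (A3 + A3') * x:

```julia
using LinearAlgebra

# Writing the quadratic form as dot(x, A3 * x) keeps the computation in
# vector/matrix products, so x' never produces an Adjoint/RowVector.
A3 = [1.0 2.0; 7.0 5.0]
quadratic_dot(x) = dot(x, A3 * x)

# Analytic gradient for comparison: (A3 + A3') * x.
expected = (A3 + A3') * ones(2)    # [11.0, 19.0]
```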
(jl_N4cJfW) pkg> st
...
[37e2e3b7] ReverseDiff v1.14.1