ReverseDiff.jl

Cannot compute derivatives of quadratic form.

Open bicycle1885 opened this issue 7 years ago • 4 comments

I tried to compute the gradient of a quadratic form, but it failed with a method ambiguity error:

import ReverseDiff
const A = [1.0 2.0; 2.0 5.0]
quadratic(x) = x' * A * x
ReverseDiff.gradient(quadratic, ones(2))
MethodError: *(::RowVector{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) is ambiguous. Candidates:
  *(x::AbstractArray{T,2} where T, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
  *(x::AbstractArray, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
  *(rowvec::RowVector{T,V} where V<:(AbstractArray{T,1} where T), vec::AbstractArray{T,1}) where T<:Real in Base.LinAlg at linalg/rowvector.jl:170
Possible fix, define
  *(::RowVector{ReverseDiff.TrackedReal{V,D,ReverseDiff.TrackedArray{V,D,1,VA,DA}},V} where V<:(AbstractArray{T,1} where T), ::ReverseDiff.TrackedArray{V,D,1,VA,DA})

Stacktrace:
 [1] * at ./operators.jl:424 [inlined]
 [2] quadratic(::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) at ./In[32]:2
 [3] Type at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/tape.jl:199 [inlined]
 [4] gradient(::Function, ::Array{Float64,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}) at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/gradients.jl:22 (repeats 2 times)

I think matrix multiplication is already supported. I'm not sure whether it is an unsupported feature or a kind of bug, so let me file an issue here.

I'm using ReverseDiff.jl v0.1.4 on Julia 0.6.

bicycle1885, Jul 01 '17 22:07

I don't know whether this addresses the problem the way you'd like, but there are a couple of ad hoc workarounds. First, compute the vector * matrix product as the column-wise sum of the element-wise product:

julia> ReverseDiff.gradient(x->sum(x.*A,1) * x,x)
2-element Array{Float64,1}:
  4.57215
 10.7736

The other possibility is to compute the two products separately, converting the intermediate result to a plain Array:

julia> ReverseDiff.gradient(x->Array(x.'*A) * x,x)
2-element Array{Float64,1}:
  4.57215
 10.7736 

Without the Array() call, the product returns a RowVector{Float64,Array{Float64,1}}, which makes ReverseDiff unhappy...

vmoens, Jul 07 '17 10:07

This is the classic "new array types which define ambiguous method definitions" problem. It might be fixed just by adding :RowVector to the ambiguity list.

jrevels, Jul 07 '17 16:07

@oViMo Thanks. I'm avoiding the problem with techniques similar to the ones you suggested.

@jrevels I tried your suggestion, but the ambiguity error is still there:

diff --git a/src/ReverseDiff.jl b/src/ReverseDiff.jl
index 6b0635c..e472e73 100644
--- a/src/ReverseDiff.jl
+++ b/src/ReverseDiff.jl
@@ -21,7 +21,7 @@ end
 
 # Not all operations will be valid over all of these types, but that's okay; such cases
 # will simply error when they hit the original operation in the overloaded definition.
-const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix)
+const ARRAY_TYPES = (:AbstractArray, :AbstractVector, :AbstractMatrix, :Array, :Vector, :Matrix, :RowVector)
 const REAL_TYPES = (:Bool, :Integer, :Rational, :BigFloat, :BigInt, :AbstractFloat, :Real, :Dual)
 
 const FORWARD_UNARY_SCALAR_FUNCS = (ForwardDiff.AUTO_DEFINED_UNARY_FUNCS..., :-, :abs, :conj)

julia> using ReverseDiff
INFO: Recompiling stale cache file /Users/kenta/.julia/lib/v0.6/ReverseDiff.ji for module ReverseDiff.

julia> const A = [1.0 2.0; 2.0 5.0]
2×2 Array{Float64,2}:
 1.0  2.0
 2.0  5.0

julia> quadratic(x) = x' * A * x
quadratic (generic function with 1 method)

julia> ReverseDiff.gradient(quadratic, ones(2))
ERROR: MethodError: *(::RowVector{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}},ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}, ::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) is ambiguous. Candidates:
  *(x::RowVector, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
  *(x::AbstractArray{T,2} where T, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
  *(x::AbstractArray, y::ReverseDiff.TrackedArray{V,D,N,VA,DA} where DA where VA where N) where {V, D} in ReverseDiff at /Users/kenta/.julia/v0.6/ReverseDiff/src/derivatives/linalg/arithmetic.jl:193
  *(rowvec::RowVector{T,V} where V<:(AbstractArray{T,1} where T), vec::AbstractArray{T,1}) where T<:Real in Base.LinAlg at linalg/rowvector.jl:170
Possible fix, define
  *(::RowVector{ReverseDiff.TrackedReal{V,D,ReverseDiff.TrackedArray{V,D,1,VA,DA}},V} where V<:(AbstractArray{T,1} where T), ::ReverseDiff.TrackedArray{V,D,1,VA,DA})
Stacktrace:
 [1] * at ./operators.jl:424 [inlined]
 [2] quadratic(::ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}) at ./REPL[3]:1
 [3] Type at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/tape.jl:199 [inlined]
 [4] gradient(::Function, ::Array{Float64,1}, ::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64,Float64,1,Array{Float64,1},Array{Float64,1}}}) at /Users/kenta/.julia/v0.6/ReverseDiff/src/api/gradients.jl:22 (repeats 2 times)

bicycle1885, Jul 08 '17 10:07

This no longer gives an error. But if A is not symmetric, it gives wrong results, as seen in this Discourse thread:

julia> const A3 = [1.0 2.0; 7.0 5.0];

julia> quadratic3(x) = x' * A3 * x;

julia> ReverseDiff.gradient(quadratic3, ones(2))  # wrong
2-element Vector{Float64}:
 16.0
 14.0

julia> ForwardDiff.gradient(quadratic3, ones(2))
2-element Vector{Float64}:
 11.0
 19.0

julia> Zygote.gradient(quadratic3, ones(2))[1]
2-element Vector{Float64}:
 11.0
 19.0

julia> ReverseDiff.gradient(x -> dot(x, A3, x), ones(2))  # dot works
2-element Vector{Float64}:
 11.0
 19.0

(jl_N4cJfW) pkg> st
...
  [37e2e3b7] ReverseDiff v1.14.1

mcabbott, Jul 04 '22 21:07
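
As a hand check on the numbers in the last comment: the gradient of a quadratic form is (A + Aᵀ)x, which reduces to 2Ax only when A is symmetric. The worked arithmetic below uses only A3 and the input ones(2) from that comment; the symbol x for the input vector is just notation.

```latex
\nabla_x \left( x^\top A x \right) = (A + A^\top)\, x

% Expected gradient for A_3 = [1 2; 7 5] at x = (1, 1)^\top:
(A_3 + A_3^\top)\, x
  = \begin{pmatrix} 2 & 9 \\ 9 & 10 \end{pmatrix}
    \begin{pmatrix} 1 \\ 1 \end{pmatrix}
  = \begin{pmatrix} 11 \\ 19 \end{pmatrix}

% For comparison, the two "symmetric-style" formulas at the same point:
2 A_3^\top x
  = 2 \begin{pmatrix} 1 & 7 \\ 2 & 5 \end{pmatrix}
      \begin{pmatrix} 1 \\ 1 \end{pmatrix}
  = \begin{pmatrix} 16 \\ 14 \end{pmatrix},
\qquad
2 A_3\, x
  = 2 \begin{pmatrix} 1 & 2 \\ 7 & 5 \end{pmatrix}
      \begin{pmatrix} 1 \\ 1 \end{pmatrix}
  = \begin{pmatrix} 6 \\ 24 \end{pmatrix}
```

So [11, 19] from ForwardDiff, Zygote, and the dot version agrees with the analytic gradient, while the ReverseDiff result [16, 14] for x' * A3 * x coincides with 2·A3ᵀ·x at this input. That is only an observation from the arithmetic, not a diagnosis of the ReverseDiff source.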