Distances.jl

`RenyiDivergence` incompatible with TrackedArray

Open • torfjelde opened this issue 6 years ago • 6 comments

MVE

```julia
using Tracker, Distances
renyi_divergence(param([1.0]), param([1.0]), 0.5)
```

results in

```
MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
  Float64(::Real, !Matched::RoundingMode) where T<:AbstractFloat at rounding.jl:185
  Float64(::T<:Number) where T<:Number at boot.jl:725
  Float64(!Matched::Int8) at float.jl:60
  ...

Stacktrace:
 [1] renyi_divergence(::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}, ::Float64) at /home/tor/.julia/dev/Distances/src/metrics.jl:382
 [2] top-level scope at In[35]:1
```

A possible "solution":

```julia
@inline Base.@propagate_inbounds function eval_start(::RenyiDivergence, a::AbstractArray{T_}, b::AbstractArray{T_}) where {T_ <: Real}
    T = eltype(a)  # returns `TrackedReal` if `a isa TrackedArray`
    zero(T), zero(T), T(sum(a)), T(sum(b))
end
```
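The failure can be reproduced without Tracker.jl. The following is a minimal sketch, assuming two hypothetical stand-in types (`TaggedReal` for `TrackedReal`, `TaggedVector` for `TrackedArray`) and two illustrative functions `start_param` and `start_eltype`; none of these are Tracker.jl or Distances.jl code. The vector subtypes `AbstractVector{Float64}` but overrides `eltype`, so converting through the type parameter hits a `MethodError` like the one in the report, while converting through `eltype` succeeds.

```julia
# Hypothetical stand-ins (NOT Tracker.jl types): `TaggedReal` plays the role
# of TrackedReal, `TaggedVector` the role of TrackedArray.
struct TaggedReal <: Real
    value::Float64
end
TaggedReal(x::TaggedReal) = x                    # identity conversion
Base.zero(::Type{TaggedReal}) = TaggedReal(0.0)

# Declared as AbstractVector{Float64}, but reports a different eltype.
struct TaggedVector <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(v::TaggedVector) = size(v.data)
Base.getindex(v::TaggedVector, i::Int) = v.data[i]
Base.eltype(::Type{TaggedVector}) = TaggedReal
Base.sum(v::TaggedVector) = TaggedReal(sum(v.data))  # "tracked" reduction

# Converts via the type parameter T_ (Float64 here); Float64(::TaggedReal)
# has no method, so this throws a MethodError.
start_param(a::AbstractArray{T_}, b::AbstractArray{T_}) where {T_<:Real} =
    (zero(T_), zero(T_), T_(sum(a)), T_(sum(b)))

# Converts via eltype, mirroring the proposal: T is TaggedReal here.
function start_eltype(a::AbstractArray{T_}, b::AbstractArray{T_}) where {T_<:Real}
    T = eltype(a)
    return (zero(T), zero(T), T(sum(a)), T(sum(b)))
end
```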

Thoughts?

torfjelde, Aug 24 '19 19:08

This just seems like a bug in TrackedArray. Why isn't the parameter T a TrackedReal?

KristofferC, Aug 24 '19 19:08

Disclaimer: my understanding of Tracker.jl is limited, to say the least.

But I think calling it a bug is a bit strong, as I'm pretty sure it's intended. By tracking the entire array, you can use matrix operations, etc. to pull back the gradient ("backprop") rather than performing the operations elementwise on an Array{TrackedReal}.
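To illustrate what tracking the whole array buys, here is a minimal sketch of an array-level pullback for `y = W * x`, assuming a hypothetical helper `matvec_with_pullback` (not Tracker.jl's API). The entire product is a single node whose pullback is two matrix operations, instead of recording every scalar multiply and add separately, as elementwise tracking with `TrackedReal` elements would.

```julia
using LinearAlgebra

# Hypothetical helper: returns the product and a closure mapping the
# cotangent Δy back to gradients with respect to W and x.
function matvec_with_pullback(W::AbstractMatrix{<:Real}, x::AbstractVector{<:Real})
    y = W * x
    # One pullback for the whole array operation (outer product, transpose-multiply).
    pullback(Δy) = (Δy * x', W' * Δy)  # (∂/∂W, ∂/∂x)
    return y, pullback
end

W = [1.0 2.0; 3.0 4.0]
x = [1.0, 1.0]
y, back = matvec_with_pullback(W, x)
dW, dx = back(ones(2))  # gradients of sum(W * x)
```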

EDIT: I agree though; it's unfortunate :/

torfjelde, Aug 24 '19 19:08

I don't really understand your comment. The docs for AbstractArray say:

```
  AbstractArray{T,N}

  Supertype for N-dimensional arrays (or array-like types) with elements of type T
```

That doesn't seem to be the case for TrackedArray.
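For the array types in Base, the documented contract holds: the `T` in `AbstractArray{T}` is exactly what `eltype` returns. A small check, using a hypothetical helper `param_eltype` that recovers `T` through dispatch:

```julia
# Hypothetical helper: recovers T from the AbstractArray{T} supertype by dispatch.
param_eltype(::AbstractArray{T}) where {T} = T

# A dense vector, a range, a view, and a BitVector.
arrays = (Float64[1, 2], 1:3, view(zeros(Float32, 4), 1:2), trues(2))

for a in arrays
    # Documented contract: the type parameter T and eltype(a) coincide.
    @assert param_eltype(a) == eltype(a)
end
```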

KristofferC, Aug 24 '19 20:08

Yeah, you're 100% right :) But I don't think reverse-mode AD can be implemented in a way that takes advantage of what I'm referring to (tracking the whole array) while still satisfying the AbstractArray definition.

I'm wondering if it's okay to use eltype in this function rather than the type parameter for the conversion call, to allow this "abuse" of the AbstractArray definition in Tracker.jl. Since Base.eltype(a::AbstractArray{T}) = T anyway, there won't be any negative side effects; the only outcome is that it also works on a TrackedArray.

torfjelde, Aug 24 '19 20:08

> I'm wondering if it's okay to use eltype in this function rather than the type parameter for the conversion call, to allow this "abuse" of the AbstractArray definition in Tracker.jl. Since Base.eltype(a::AbstractArray{T}) = T anyway, there won't be any negative side effects; the only outcome is that it also works on a TrackedArray.

If we started doing that, a lot of code would have to be updated. Many packages (and Base itself) rely on the guarantee that the AbstractArray type parameter is the same as eltype, so a proper fix will have to be found instead.

nalimilan, Sep 04 '19 13:09

Understandable :+1: I guess the only proper fix would be to dispatch differently on TrackedArray for the 3 different signature combinations.
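A rough sketch of what dispatching on the three combinations could look like, assuming a hypothetical `MyTracked` wrapper in place of `TrackedArray` and a hypothetical `divergence` function in place of `renyi_divergence`; each method body just returns a tag showing which method was selected, where a real implementation would do the tracked computation.

```julia
# Hypothetical wrapper standing in for TrackedArray (not Tracker.jl's type).
struct MyTracked{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(t::MyTracked) = size(t.data)
Base.IndexStyle(::Type{<:MyTracked}) = IndexLinear()
Base.getindex(t::MyTracked, i::Int) = t.data[i]

# Generic fallback for plain arrays.
divergence(a::AbstractArray, b::AbstractArray) = :plain_plain
# The three combinations involving at least one tracked argument.
divergence(a::MyTracked, b::AbstractArray) = :tracked_plain
divergence(a::AbstractArray, b::MyTracked) = :plain_tracked
divergence(a::MyTracked, b::MyTracked) = :tracked_tracked
```

Defining the (tracked, tracked) method explicitly also resolves what would otherwise be a method ambiguity between the two mixed methods.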

torfjelde, Sep 04 '19 13:09