`RenyiDivergence` incompatible with TrackedArray
Minimal working example:
using Tracker, Distances
renyi_divergence(param([1.0]), param([1.0]), 0.5)
results in
MethodError: no method matching Float64(::Tracker.TrackedReal{Float64})
Closest candidates are:
Float64(::Real, !Matched::RoundingMode) where T<:AbstractFloat at rounding.jl:185
Float64(::T<:Number) where T<:Number at boot.jl:725
Float64(!Matched::Int8) at float.jl:60
...
Stacktrace:
[1] renyi_divergence(::TrackedArray{…,Array{Float64,1}}, ::TrackedArray{…,Array{Float64,1}}, ::Float64) at /home/tor/.julia/dev/Distances/src/metrics.jl:382
[2] top-level scope at In[35]:1
A possible "solution":
@inline Base.@propagate_inbounds function eval_start(::RenyiDivergence, a::AbstractArray{T_}, b::AbstractArray{T_}) where {T_ <: Real}
    T = eltype(a) # returns `TrackedReal` if `a isa TrackedArray`
    zero(T), zero(T), T(sum(a)), T(sum(b))
end
Thoughts?
This just seems like a bug in TrackedArray. Why isn't the parameter T a TrackedReal?
Disclaimer: my understanding of Tracker.jl is limited, to say the least.
But I think "bug" is a bit strong, as I'm pretty sure it's intended. By tracking the entire array, you can use matrix operations etc. to pull back the gradient ("backprop") rather than performing the operations elementwise on an `Array{TrackedReal}`.
EDIT: I agree though; it's unfortunate :/
I don't really understand your comment. The docs for `AbstractArray` say
AbstractArray{T,N}
Supertype for N-dimensional arrays (or array-like types) with elements of type T
That doesn't seem to be the case for TrackedArray.
Yeah, you're 100% right :) But I don't think there's a way to implement reverse-mode AD that takes advantage of what I'm referring to while still satisfying the `AbstractArray` definition.
I'm wondering if it's okay to use `eltype` in this function rather than the parametric type for the conversion call, to allow this "abuse" of the `AbstractArray` definition in Tracker.jl. Since `Base.eltype(a::AbstractArray{T}) = T` anyway, there won't be any negative side effects; the only outcome is that it just works on a `TrackedArray`.
> I'm wondering if it's okay to use `eltype` in this function rather than the parametric type for the conversion call, to allow this "abuse" of the `AbstractArray` definition in Tracker.jl. Since `Base.eltype(a::AbstractArray{T}) = T` anyway, there won't be any negative side effects; the only outcome is that it just works on a `TrackedArray`.
If we started doing that, a lot of code would have to be updated. Given that the guarantee that the `AbstractArray` type parameter is equivalent to `eltype` is relied upon in many packages (and in Base), a proper fix will have to be found instead.
Understandable :+1:
I guess the only proper fix would be to dispatch differently on TrackedArray for the 3 different signature combinations.
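For reference, the mismatch discussed in this thread can be reproduced without Tracker.jl. The sketch below is only an illustration: `Wrapped`, `WrappedVector`, and `param_type` are hypothetical names standing in for `TrackedReal`, `TrackedArray`, and a method that dispatches on the `AbstractArray{T}` parameter. It is not how Tracker.jl is actually implemented.

```julia
# Hypothetical scalar wrapper, standing in for Tracker's TrackedReal.
struct Wrapped{T<:Real} <: Real
    value::T
end

# Hypothetical array wrapper, standing in for TrackedArray: the AbstractArray
# type parameter is the *underlying* element type T ...
struct WrappedVector{T<:Real} <: AbstractVector{T}
    data::Vector{T}
end

Base.size(w::WrappedVector) = size(w.data)
Base.getindex(w::WrappedVector, i::Int) = Wrapped(w.data[i])

# ... while eltype is overridden to report the wrapper type (assumption: this
# mirrors the behaviour described in the thread for TrackedArray).
Base.eltype(::Type{WrappedVector{T}}) where {T} = Wrapped{T}

# A method that reads the element type from the type parameter, like the
# `where {T <: Real}` signature in `eval_start`.
param_type(::AbstractArray{T}) where {T} = T

w = WrappedVector([1.0, 2.0])

param_type(w)  # the AbstractArray parameter: Float64
eltype(w)      # the overridden eltype: Wrapped{Float64}
```

Any code that assumes the two always agree, such as converting a result with `T(...)`, can hit a `MethodError` like the one reported at the top of this issue when they differ.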