Tracker.jl

`vcat` with tracked reals

Open · MikeInnes opened this issue 7 years ago · 1 comment

    vcat(param(1),param(2:3))

This currently returns an array of tracked scalars; would be good to catch it in the tracking code.

MikeInnes · Jun 08 '18 16:06

Actually, any hcat or vcat call that includes a TrackedReal fails with a method ambiguity in Flux v0.8.1 / Tracker v0.1.0, because TrackedReal <: Number. I've tried to fix this, but I don't know enough about it. IMO the entire dispatch scheme used for hcat/vcat in lib/array.jl is broken. I'd have thought this was a case for defining dedicated functions (probably Tracker.vcat, not Base.vcat) with trait-based dispatch, using any() on a "tracked value" trait such as typeof(x) <: Union{TrackedReal,TrackedArray,...}.

    using Flux.Tracker
    vcat(param(1),param(2))

    ERROR: MethodError: vcat(::Tracker.TrackedReal{Float64}, ::Tracker.TrackedReal{Float64}) is ambiguous. Candidates:
      vcat(368::Number, x::Union{TrackedArray, TrackedReal}, xs::Union{Number, AbstractArray}...) in Tracker at /home/triggs/.julia/packages/Tracker/6wcYJ/src/lib/array.jl:167
      vcat(x::Union{TrackedArray, TrackedReal}, xs::Union{Number, AbstractArray}...) in Tracker at /home/triggs/.julia/packages/Tracker/6wcYJ/src/lib/array.jl:167
    Possible fix, define
      vcat(::Tracker.TrackedReal, ::Union{TrackedArray, TrackedReal}, ::Vararg{Union{Number, AbstractArray},N} where N)
    Stacktrace:
     [1] top-level scope at none:0
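To make the trait-based idea above concrete, here is a minimal sketch. It is not Tracker's actual code: `FakeTrackedReal`, `FakeTrackedArray`, `istrackedvalue`, `unwrap` and `tvcat` are all hypothetical stand-ins, chosen so the example runs without Tracker installed. The point is only that a single package-local vararg method which checks a trait with `any()` cannot be ambiguous, whereas several overlapping `Base.vcat` methods can.

```julia
# Hypothetical stand-ins for Tracker's types (assumptions, not the real definitions).
struct FakeTrackedReal{T<:Real} <: Real
    data::T
end

struct FakeTrackedArray{T,N} <: AbstractArray{T,N}
    data::Array{T,N}
end
Base.size(x::FakeTrackedArray) = size(x.data)
Base.getindex(x::FakeTrackedArray, i::Int...) = x.data[i...]

# The "tracked value" trait: true for tracked scalars and arrays, false otherwise.
istrackedvalue(x) = false
istrackedvalue(::FakeTrackedReal) = true
istrackedvalue(::FakeTrackedArray) = true

# Strip the wrapper from tracked values; pass everything else through unchanged.
unwrap(x) = x
unwrap(x::Union{FakeTrackedReal,FakeTrackedArray}) = x.data

# Package-local vcat (not Base.vcat). There is only one method, so dispatch
# cannot be ambiguous; the tracked/untracked decision is made at run time.
function tvcat(xs...)
    raw = Base.vcat(map(unwrap, xs)...)
    # A real implementation would record the call for backpropagation here.
    return any(istrackedvalue, xs) ? FakeTrackedArray(raw) : raw
end

# Usage: tracked scalars give a tracked array, plain values fall back to Base.
a = tvcat(FakeTrackedReal(1.0), FakeTrackedReal(2.0))   # FakeTrackedArray wrapping [1.0, 2.0]
b = tvcat(1.0, [2.0, 3.0])                              # plain Vector{Float64}
```

The trade-off in this sketch is that callers must use the package-local function rather than `Base.vcat`, which is the same trade-off the comment above accepts for `Tracker.vcat`.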

Further to this, it's great that you fixed Tracker.collect on TrackedReals to return a TrackedArray, but on untracked values it isn't falling back to untracked Base.collect correctly:

    Tracker.collect([1.0,2.0])
    Tracked 2-element Array{Float64,1}:
     1.0
     2.0

This could be fixed by switching on istracked:

    function Tracker.collect(xs)
      xs = Base.collect(xs)
      return any(istracked,xs) ? track(Call(collect, (tracker.(xs),)), data.(xs)) : xs
    end

BillTriggs · Mar 28 '19 16:03