Tracker.jl

Behavior on `vcat` of scalars

Open · gdalle opened this issue · 2 comments

Here's an inconsistency I found between two ways of constructing a vector from a float:

```julia
julia> using Tracker

julia> f1(x) = [x, x];

julia> f2(x) = vcat(x, x);

julia> x = 1.0;

julia> dy = ones(2);

julia> y1, pb1 = Tracker.forward(f1, x);

julia> y2, pb2 = Tracker.forward(f2, x);

julia> y1 == y2
true

julia> pb1(dy) == pb2(dy)
false

julia> pb1(dy)  # returns a scalar
(2.0,)

julia> pb2(dy)  # returns an array
([2.0],)
```

gdalle, Jan 29 '25 10:01

Shorter example:

```julia
julia> Tracker.gradient(sum∘vcat, [1.0, 2.0], [3.0])  # fine
([1.0, 1.0] (tracked), [1.0] (tracked))

julia> Tracker.gradient(sum∘vcat, 1.0, 2.0)  # creates a vector where it wants a scalar
ERROR: MethodError: Cannot `convert` an object of type Vector{Float64} to an object of type Float64

Stacktrace:
  [1] setproperty!(x::Tracker.Tracked{Float64}, f::Symbol, v::Vector{Float64})
    @ Base ./Base.jl:52
  [2] back(x::Tracker.Tracked{Float64}, Δ::Vector{Float64}, once::Bool)
    @ Tracker ~/.julia/packages/Tracker/6rnwO/src/back.jl:48
  [3] (::Tracker.var"#707#708"{Bool})(x::Tracker.Tracked{Float64}, d::Vector{Float64})
    @ Tracker ~/.julia/packages/Tracker/6rnwO/src/back.jl:38

julia> Tracker.gradient(sum∘vcat, [1.0, 2.0], 3.0)  # same error with mix of vector & scalar args
ERROR: MethodError: Cannot `convert` an object of type Vector{Float64} to an object of type Float64
```

Note that I wouldn't call this "inconsistent": `vcat` of scalars seems to always go wrong. The other function, `f1(x) = [x, x]`, lowers to `Base.vect`, not `vcat` (see e.g. `Meta.@lower [1, 2]`). The syntax for `vcat` uses a semicolon: `[1; 2]`.
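The distinction can be checked with Base alone, without Tracker: the comma literal parses to a `:vect` expression and the semicolon literal to a `:vcat` expression. A minimal sketch (the variable names are illustrative only):

```julia
# Parse (but do not evaluate) the two literal forms.
ex_vect = :([x, x])   # comma-separated literal
ex_vcat = :([x; x])   # semicolon-separated literal

println(ex_vect.head)  # prints vect: lowers to a Base.vect(x, x) call
println(ex_vcat.head)  # prints vcat: lowers to a Base.vcat(x, x) call

# For plain Float64 values both calls build the same Vector{Float64}.
x = 1.0
@assert Base.vect(x, x) == vcat(x, x) == [1.0, 1.0]
```

So `f1` and `f2` in the original report compute the same value through different functions, and only the `vcat` path hits Tracker's array-oriented rule.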

mcabbott, Jan 29 '25 15:01

The `vcat` rule at https://github.com/FluxML/Tracker.jl/blob/ce231f93b275079bf8390724a383eb762c092791/src/lib/array.jl#L214-L220 looks to be tailored for array arguments only. `hcat` has a similar code path and thus the same problem. `Base.vect` does not appear to have a rule, so Tracker traces through it transparently and the individual tracked scalars remain intact.
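To illustrate what a scalar-aware rule would need to do, here is a minimal sketch in plain Julia, without Tracker. `vcat_pullback` is a hypothetical helper, not part of Tracker's API, and it only handles scalar and vector arguments, where `vcat` builds a vector. The key step is unwrapping the one-element cotangent slice for a scalar argument with `only`, so the gradient is a `Float64` (what a `Tracked{Float64}` can store) rather than the `Vector{Float64}` that triggers the `convert` error above.

```julia
# Hypothetical helper (not Tracker's API): given the cotangent Δ of
# y = vcat(xs...), return one gradient per argument with that argument's shape.
# Restricted to scalar and vector arguments, where vcat builds a vector.
function vcat_pullback(Δ::AbstractVector, xs::Union{Number,AbstractVector}...)
    length(Δ) == sum(length, xs) ||
        throw(DimensionMismatch("Δ does not match the total length of the arguments"))
    grads = Any[]
    offset = 0
    for x in xs
        n = length(x)                  # length(3.0) == 1 for a scalar
        piece = Δ[offset+1:offset+n]   # slice of Δ belonging to this argument
        # Unwrap the one-element slice for scalar arguments, so the gradient
        # is a Float64 rather than a Vector{Float64}.
        push!(grads, x isa Number ? only(piece) : piece)
        offset += n
    end
    return Tuple(grads)
end

# Usage, mirroring the failing calls in the thread:
vcat_pullback(ones(2), 1.0, 2.0)         # (1.0, 1.0)
vcat_pullback(ones(3), [1.0, 2.0], 3.0)  # ([1.0, 1.0], 1.0)
```

For `f2(x) = vcat(x, x)` both pieces would be scalars, and accumulating them into the single tracked `x` would give `2.0`, matching what `pb1` returns, instead of `[2.0]`. A real fix would also need to handle matrices and the `hcat` case, which this sketch does not attempt.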

ToucheSir, Jan 29 '25 17:01