Tracker.jl

Broadcasting and constructors

Open torfjelde opened this issue 2 years ago • 8 comments

There seems to be something weird going on when broadcasting over Real:

julia> using Tracker, Distributions

julia> m = first(param(zeros(1)))
0.0 (tracked)

julia> s = first(param(ones(1)))
1.0 (tracked)

julia> typeof(Normal.(m, s))
Tracker.Tracked{Normal{Float64}}

But the strange thing is that the following seems to work just fine:

julia> struct TwoFields{T1,T2}
           x::T1
           y::T2
       end

julia> TwoFields.(first(param(zeros(1))), first(param(zeros(1))))
TwoFields{Float64, Float64}(0.0, 0.0)

Possibly related to: #65

torfjelde, Jan 13 '22 21:01

Maybe https://github.com/JuliaStats/Distributions.jl/blob/master/src/common.jl#L149 messes with Tracker's broadcasting? What happens if you add the same definition for your custom struct?

devmotion, Jan 13 '22 21:01

Looks fine:

julia> using Tracker

julia> struct TwoFields{T1,T2}
           x::T1
           y::T2
       end

julia> Broadcast.broadcastable(d::TwoFields) = Ref(d)

julia> TwoFields.(first(param(zeros(1))), first(param(zeros(1))))
TwoFields{Float64, Float64}(0.0, 0.0)

torfjelde, Jan 13 '22 21:01

OK, I figured out why this is happening: It's caused by the heuristic in https://github.com/FluxML/Tracker.jl/blob/043da25fe491a5c705898e4de90cb8567f93373f/src/lib/array.jl#L558 and the definition of eltype(::Normal{T}) = T (well, actually Distributions defines it on types as recommended in the Julia docs since instances fall back to it).

If one defines eltype for TwoFields such that it returns a subtype of Real, the same behaviour can be observed, e.g.:

julia> Base.eltype(::Type{TwoFields{T1,T2}}) where {T1,T2} = Base.promote_type(T1, T2)

julia> typeof(TwoFields.(first(param(zeros(1))), first(param(zeros(1)))))
Tracker.Tracked{TwoFields{Float64, Float64}}
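The fallback mentioned above can be checked without Tracker. The following is a minimal sketch in plain Julia, assuming only Base; `tf`, `before` and `after` are names introduced here for illustration. It shows that defining eltype on the type is enough for instances to report it, which is what an eltype-based check would then see.

```julia
# Same struct as in the examples in this thread.
struct TwoFields{T1,T2}
    x::T1
    y::T2
end

tf = TwoFields(0.0, 0.0)

# Without a custom method, Base's generic fallback eltype(::Type) = Any applies.
before = eltype(tf)

# Define eltype on the type only; instances pick it up through
# Base's fallback eltype(x) = eltype(typeof(x)).
Base.eltype(::Type{TwoFields{T1,T2}}) where {T1,T2} = promote_type(T1, T2)

after = eltype(tf)

println("eltype before: ", before)
println("eltype after:  ", after)
println("after <: Real: ", after <: Real)
```

Once `after <: Real` holds, any check that only looks at `eltype(y)` can no longer tell a `TwoFields` value apart from a real number or an array of reals.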

devmotion, Jan 13 '22 22:01

Aaaah..

So is the fix to also check y isa AbstractArray or something?

torfjelde, Jan 14 '22 15:01

I'm not completely sure about the motivation of this check, but to me it seems the heuristic is supposed to drop tracking information in cases where the output is known to be non-differentiable. I think the heuristic should rather be too strict and avoid dropping tracking information silently. Maybe it could, e.g., be restricted to y isa Union{Bool,AbstractArray{<:Bool}}. I'm worried, though, that changes to this heuristic would break many downstream packages.
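To make the difference between the two kinds of check concrete, here is a rough sketch in plain Julia. It does not reproduce Tracker's source: `drops_by_eltype` is only an assumed paraphrase of the eltype-based behaviour described in this thread, `drops_by_type` is the stricter suggestion, and `CoinFlip` is a hypothetical non-container type that reports eltype Bool.

```julia
# Hypothetical sampler-like type: not a container, but it reports eltype Bool.
struct CoinFlip
    p::Float64
end
Base.eltype(::Type{CoinFlip}) = Bool

# Assumed paraphrase of an eltype-based check (not Tracker's actual code):
# drop tracking when the eltype is not Real, or is Bool.
drops_by_eltype(y) = !(eltype(y) <: Real) || eltype(y) == Bool

# Stricter variant: only actual Bools and arrays of Bools drop tracking.
drops_by_type(y) = y isa Union{Bool,AbstractArray{<:Bool}}

candidates = (true, [true, false], 1.0, [1.0, 2.0], CoinFlip(0.5))
for y in candidates
    println(rpad(repr(y), 16),
            " eltype-based: ", drops_by_eltype(y),
            "  type-based: ", drops_by_type(y))
end
```

The two predicates agree on Bools, Bool arrays and floats; they differ only on values like `CoinFlip(0.5)`, where eltype does not describe a container's elements.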

devmotion, Jan 14 '22 16:01

Anything with eltype Bool gets rejected (including bools themselves), so as long as the conditional remains the same, we should be compatible with the rest of the packages. I would test Turing/SciML against such a branch to be safe regardless, unless someone is relying on the behaviour in the MWE.

DhairyaLGandhi, Jan 14 '22 17:01

> Anything with eltype Bool gets rejected (including bools themselves), so as long as the conditional remains the same,

The fix for Distributions (and e.g. samplers in general: https://docs.julialang.org/en/v1/stdlib/Random/#A-simple-sampler-without-pre-computed-data) requires changing, i.e., probably restricting, the conditional: eltype is used not only for arrays or standard containers but also in other settings, which causes the problems in this issue. For instance, eltype(::Bernoulli) = Bool (https://github.com/JuliaStats/Distributions.jl/blob/5cb0bfc0383180341c19a478ba859190b5b728a0/src/univariate/discrete/bernoulli.jl#L43), but it is completely reasonable to differentiate loglikelihood(Bernoulli(p), x) = sum(x) * log(p) + (length(x) - sum(x)) * log1p(-p) with respect to the parameter p, where p is a TrackedReal and x is an untracked Array{Bool}. So it seems a fix would have to make sure that not everything with eltype Bool drops the tracking information.
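As a sanity check that this log-likelihood is smooth in p even though the data are Bools, here is a small sketch in plain Julia (no Tracker or Distributions). The names `bernoulli_loglik` and `dloglik_dp` are introduced here for illustration; the analytic derivative is compared against a central finite difference.

```julia
# Log-likelihood of i.i.d. Bernoulli(p) observations, as written above.
bernoulli_loglik(p, x) = sum(x) * log(p) + (length(x) - sum(x)) * log1p(-p)

# Analytic derivative with respect to p: k/p - (n - k)/(1 - p), with k = sum(x).
dloglik_dp(p, x) = sum(x) / p - (length(x) - sum(x)) / (1 - p)

x = [true, false, true, true]   # untracked Bool data
p = 0.3                         # parameter we differentiate with respect to
h = 1e-6                        # finite-difference step

# Central finite-difference approximation of the derivative at p.
fd = (bernoulli_loglik(p + h, x) - bernoulli_loglik(p - h, x)) / (2h)

println("analytic:          ", dloglik_dp(p, x))
println("finite difference: ", fd)
println("agree:             ", isapprox(fd, dloglik_dp(p, x); rtol = 1e-6))
```

The derivative exists for every p in (0, 1), so a check that discards tracking purely because eltype is Bool would lose a perfectly valid gradient here.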

devmotion, Jan 15 '22 21:01

I don't quite understand what ∇broadcast is doing and why it differs from https://github.com/FluxML/Zygote.jl/blob/v0.6.51/src/lib/broadcast.jl#L188-L207, but perhaps Tracker could borrow some of Zygote's logic here.

ToucheSir, Jan 06 '23 06:01