Tracker.jl
Broadcasting and constructors
There seems to be something weird going on when broadcasting over Real:
julia> using Tracker, Distributions
julia> m = first(param(zeros(1)))
0.0 (tracked)
julia> s = first(param(ones(1)))
1.0 (tracked)
julia> typeof(Normal.(m, s))
Tracker.Tracked{Normal{Float64}}
But the strange thing is that the following seems to work just fine:
julia> struct TwoFields{T1,T2}
    x::T1
    y::T2
end
julia> TwoFields.(first(param(zeros(1))), first(param(zeros(1))))
TwoFields{Float64, Float64}(0.0, 0.0)
Possibly related to: #65
Maybe https://github.com/JuliaStats/Distributions.jl/blob/master/src/common.jl#L149 messes with Tracker's broadcasting? What happens if you add the same definition for your custom struct?
Looks fine:
julia> using Tracker
julia> struct TwoFields{T1,T2}
    x::T1
    y::T2
end
julia> Broadcast.broadcastable(d::TwoFields) = Ref(d)
julia> TwoFields.(first(param(zeros(1))), first(param(zeros(1))))
TwoFields{Float64, Float64}(0.0, 0.0)
OK, I figured out why this is happening: it's caused by the heuristic in https://github.com/FluxML/Tracker.jl/blob/043da25fe491a5c705898e4de90cb8567f93373f/src/lib/array.jl#L558 and the definition of eltype(::Normal{T}) = T (well, actually Distributions defines it on types, as recommended in the Julia docs, since instances fall back to it).
If one defines eltype for TwoFields such that it returns a subtype of Real, the same behaviour can be observed, e.g.:
julia> Base.eltype(::Type{TwoFields{T1,T2}}) where {T1,T2} = Base.promote_type(T1, T2)
julia> typeof(TwoFields.(first(param(zeros(1))), first(param(zeros(1)))))
Tracker.Tracked{TwoFields{Float64, Float64}}
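To make the eltype fallback concrete, here is a minimal sketch in plain Julia (no Tracker or Distributions required). PlainPair, NumericPair and looks_numeric are hypothetical names, and looks_numeric only assumes that the heuristic tests whether eltype(y) is a subtype of Real; it is not copied from Tracker's source.

```julia
# Hypothetical stand-ins; neither type is part of Tracker or Distributions.
struct PlainPair{T1,T2}
    x::T1
    y::T2
end

struct NumericPair{T1,T2}
    x::T1
    y::T2
end

# eltype defined on the type; instances fall back to it via
# eltype(x) = eltype(typeof(x)).
Base.eltype(::Type{NumericPair{T1,T2}}) where {T1,T2} = promote_type(T1, T2)

# Assumed shape of an eltype-based check (illustration only).
looks_numeric(y) = eltype(y) <: Real

# A stricter alternative that looks at the value itself.
is_real_or_real_array(y) = y isa Union{Real,AbstractArray{<:Real}}

plain = PlainPair(0.0, 1.0)
numeric = NumericPair(0.0, 1.0)

looks_numeric(plain)                 # false: eltype falls back to Any
looks_numeric(numeric)               # true, although it is neither a number nor an array
is_real_or_real_array(numeric)       # false
is_real_or_real_array([0.0, 1.0])    # true
```

The two structs differ only in the eltype method, which is enough to flip the eltype-based check.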
Aaaah..
So is the fix to also check y isa AbstractArray or something?
I'm not completely sure about the motivation for this check, but to me it seems the heuristic is supposed to drop tracking information in cases where the output is known to be non-differentiable. I think the heuristic should rather be too strict and avoid dropping tracking information silently. Maybe it could be restricted to, e.g., y isa Union{Bool,AbstractArray{<:Bool}}. I'm worried, though, that changes to this heuristic would break many downstream packages.
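A rough sketch of how the restricted check would differ from an eltype-based one, again in plain Julia. CoinFlip, eltype_is_bool and is_bool_value are hypothetical names used for illustration only; CoinFlip merely imitates a type that reports eltype Bool without being a Bool or an array.

```julia
# Hypothetical type whose eltype is Bool; not the Distributions.jl Bernoulli.
struct CoinFlip{T<:Real}
    p::T
end
Base.eltype(::Type{<:CoinFlip}) = Bool

# Assumed shape of an eltype-based check (illustration only).
eltype_is_bool(y) = eltype(y) == Bool

# The restriction suggested above: only actual Bool values and Bool arrays.
is_bool_value(y) = y isa Union{Bool,AbstractArray{<:Bool}}

c = CoinFlip(0.3)

eltype_is_bool(c)               # true: would be treated as non-differentiable
is_bool_value(c)                # false: stays eligible for tracking
is_bool_value(true)             # true
is_bool_value([true, false])    # true
is_bool_value(falses(2))        # true: BitVector <: AbstractArray{Bool}
```

Under these assumptions, plain Bool outputs are still rejected as before, while a parameterised object that merely reports eltype Bool is not.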
Anything with eltype Bool gets rejected (including bools themselves), so as long as the conditional remains the same, we should stay compatible with the rest of the packages, unless someone is relying on the behaviour in the MWE. I would test Turing/SciML against such a branch to be safe regardless.
> Anything with eltype Bool gets rejected (including bools themselves), so as long as the conditional remains the same,
The fix for Distributions (and e.g. samplers in general: https://docs.julialang.org/en/v1/stdlib/Random/#A-simple-sampler-without-pre-computed-data) requires changing, i.e. probably restricting, the conditional: eltype is used not only for arrays or standard containers but also in other settings, which is what causes the problems in this PR. For instance, eltype(::Bernoulli) = Bool (https://github.com/JuliaStats/Distributions.jl/blob/5cb0bfc0383180341c19a478ba859190b5b728a0/src/univariate/discrete/bernoulli.jl#L43), but it is completely reasonable to differentiate loglikelihood(Bernoulli(p), x) = sum(x) * log(p) + (length(x) - sum(x)) * log1p(-p) with respect to the parameter p, where p is a TrackedReal and x is an untracked Array{Bool}. So it seems a fix would have to make sure that not everything with eltype Bool drops the tracking information.
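For reference, the gradient in question is not trivial. Writing s = sum(x) and n = length(x) (shorthand introduced here, not in the original comment) and using log1p(-p) = log(1 - p), differentiating the log-likelihood above with respect to p gives:

```latex
\frac{\partial}{\partial p}\Bigl[\, s \log p + (n - s)\log(1 - p) \Bigr]
  = \frac{s}{p} - \frac{n - s}{1 - p}
```

This is nonzero in general, so silently returning an untracked result would lose a meaningful derivative.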
I don't quite understand what ∇broadcast is doing and why it differs from https://github.com/FluxML/Zygote.jl/blob/v0.6.51/src/lib/broadcast.jl#L188-L207, but perhaps Tracker could borrow some of Zygote's logic here.