rrule for fill!
related to #515
I think the difficulty with allowing this is that it will cause any other rule which has captured x to give wrong answers:
julia> Zygote.gradient([1,2,3]) do x
           y = log.(x)
           fill!(x, 0)
           sum(y .+ x)
       end[1]
3-element Vector{Float64}:
2.0
1.5
1.3333333333333333
julia> ForwardDiff.gradient([1,2,3]) do x
           y = log.(x)
           fill!(x, 0)
           sum(y .+ x)
       end
3-element Vector{Float64}:
1.0
0.5
0.3333333333333333
Is there some clever way we might avoid this, or at least, make this example where x is used elsewhere an error, without making fill!(similar(x), y) an error? What if the gradient with respect to x is some KillerTangent(), which explodes on contact but is quietly thrown away by the rrule for similar, etc?
Can't we trust users to use this in a safe way? We do it already for rand!.
It's harder for me to picture rand! going wrong in the wild, but it does have the same problem.
Seems to be from #252, without discussion.
I don't see a general solution for mutating functions violating implicit assumptions made by other rrules.
Giving up on them entirely is a bit annoying: it forces users into awkward alternative paths, into using ignore blocks, or into defining their own rrules. On the other hand, if we implement rules for mutating functions and people use them too freely, they are going to shoot themselves in the foot.
Taking a general stance on this requires more thought. In this specific case though, given how much the fill!(similar(x), y) pattern appears in the wild, I would be more on the permissive side and go with something like this PR.
PS I don't understand what the test failure means
What if the gradient with respect to x is some KillerTangent(), which explodes on contact but is quietly thrown away by the rrule for similar, etc?
We do have NotImplemented which kind of does that.
It poisons everything it touches, making that also return NotImplemented with the same message:
Pretty much any time you pullback a NotImplemented you get the same NotImplemented.
And what would happen for
x = similar(...)
y = fill!(x, a)
is that we call pullback_fill! and get x̄ = NotImplemented(),
which we pass to pullback_similar, but that returns NoTangent() for all its inputs since it is not differentiable anyway.
And in the
julia> Zygote.gradient([1,2,3]) do x
           y = log.(x)
           fill!(x, 0)
           sum(y .+ x)
       end[1]
case
the broadcast_pullback will get a NotImplemented, which it will then pass on, until at the end the user gets a NotImplemented as the output.
It should display some nice message explaining that mutation is not supported.
(We probably also want a way to turn NotImplemented construction into errors, so users can work out where in their code they called e.g. fill!.)
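The propagation described above can be sketched in plain Julia without any AD package. This is only a toy model: `Poison`, `accum`, `fill_pullback`, `similar_pullback`, `safe_pattern` and `unsafe_pattern` are hypothetical names invented for the illustration and are not the ChainRulesCore types or rules. The reverse passes are written out by hand for the two patterns under discussion.

```julia
# Hypothetical stand-in for a poisoned tangent (not ChainRulesCore.NotImplemented).
struct Poison
    msg::String
end

# Accumulating anything with a Poison yields the Poison.
accum(a, b) = a .+ b
accum(p::Poison, ::Any) = p
accum(::Any, p::Poison) = p
accum(p::Poison, ::Poison) = p

# Hand-written pullbacks: fill! poisons its array argument, similar drops its input.
fill_pullback(dy) = (Poison("fill! mutated its first argument"), sum(dy))
similar_pullback(::Any) = nothing

# Reverse pass for: y = fill!(similar(x), s); sum(y)
function safe_pattern(x, s)
    dy = ones(length(x))
    dtmp, ds = fill_pullback(dy)
    dx = similar_pullback(dtmp)      # the poison is discarded here
    return dx, ds
end

# Reverse pass for: y = log.(x); z = fill!(x, s); sum(y .+ z)
function unsafe_pattern(x, s)
    dy = ones(length(x))
    dz = ones(length(x))
    dx_from_fill, ds = fill_pullback(dz)
    dx_from_log = dy ./ x            # contribution from log.(x)
    return accum(dx_from_log, dx_from_fill), ds   # the poison wins
end

safe = safe_pattern([1, 2, 3], 0)
unsafe = unsafe_pattern([1, 2, 3], 0)
```

In this model the `similar` pattern returns an ordinary gradient for `s`, while the pattern that mutates a user-visible `x` surfaces the poison to the caller instead of a number.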
Oh right, I guess NotImplemented must have roughly the rules I imagined.
However, my original example is trickier, since the pullback never gets called.
julia> function ChainRulesCore.rrule(::typeof(fill!), A::Vector, x::Number)
           project = x isa Union{Number, AbstractArray{<:Number}} ? ProjectTo(x) : identity
           fill!_pullback(Ȳ) = (NoTangent(), @not_implemented("nope"), project(sum(Ȳ)))
           return fill!(A, x), fill!_pullback
       end
julia> Zygote.gradient([1,2,3], 0) do x, s
           y = fill!(similar(x), s)
           sum(y)
       end
(nothing, 3.0)
julia> Zygote.gradient([1,2,3], 0) do x, s
           y = log.(x)
           z = fill!(x, s)
           sum(y .+ z)
       end
(NotImplemented(Main, #= REPL[7]:3 =#, nope), 3.0)
julia> Zygote.gradient([1,2,3], 0) do x, s
           y = log.(x)
           fill!(x, s)  # still leads to silent errors
           sum(y .+ x)
       end
([2.0, 1.5, 1.3333333333333333], nothing)
Huh, is the difference whether it was assigned to z or not?
I didn't know Zygote reasons about variables in that way.
I wonder if that could be changed?
I don't know the internals, maybe this could be changed? It calls the rrule on the forward pass, but not the pullback, I think because what's returned by fill! isn't an input to any later function, hence there is no input for the pullback. (FWIW Diffractor fails on both of these, right now.)
Here's an update.
- Since the above, Zygote seems to have been taught to silently ignore NotImplemented, giving more ways to get wrong answers.
- Diffractor does always call the pullback. When the return of fill! is not used, this will get Zero input. Perhaps the pullback should have a method for such cases.
- Using NotImplemented to mark x as poisoned, without restoring its value, isn't broad enough. In the examples above, log.(x) captures the old value of x and uses this for x's gradient. But something like x * y captures both values, and uses the old value of x for y's gradient.
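The `x * y` case in the last point can be reproduced by hand, with no AD package involved. The sketch below uses a hypothetical helper, `mul_with_pullback`, written for this illustration only: its pullback holds a reference to `y`, so a later `fill!(y, 7)` changes the gradient it reports for `x`, even though `x` itself was never touched and would not be marked as poisoned.

```julia
# Hypothetical hand-written rule for matrix-vector multiplication.
# The pullback closes over x and y by reference; no copies are made.
function mul_with_pullback(x, y)
    pullback(dxy) = (dxy * y', x' * dxy)   # (x̄, ȳ)
    return x * y, pullback
end

x = [1 2; 3 4]
y = [5, 6]

xy, mul_back = mul_with_pullback(x, y)

y_old = copy(y)   # kept only so the correct answer can be computed afterwards
fill!(y, 7)       # mutation happens after the forward pass through x * y

dxy = ones(2)
dx_stale, _ = mul_back(dxy)    # uses the overwritten y
dx_expected = dxy * y_old'     # what x̄ should have been
```

Here `dx_stale` is a matrix of sevens while `dx_expected` has rows `5 6`, which is the same discrepancy as between the reverse-mode and ForwardDiff results in the new example in this comment.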
julia> using Zygote, ChainRulesCore, ForwardDiff, Diffractor
# New rule, with back(::Zero) method
julia> function ChainRulesCore.rrule(::typeof(fill!), A::Vector, x::Number)
           function back(dB)
               println("pullback for fill! got $dB")
               (NoTangent(), @not_implemented("arg is mutated"), sum(dB))
           end
           function back(dB::AbstractZero)
               println("pullback for fill! got $dB")
               (NoTangent(), @not_implemented("mutated"), @not_implemented("no input"))
           end
           fill!(A,x), back
       end
julia> Zygote.gradient([1,2], 3) do x, s  # easy case, works as desired
           y = fill!(similar(x), s)
           sum(abs2, y)
       end
pullback for fill! got [6, 6]
(nothing, 12.0)
# Example from above
julia> Zygote.gradient([1,2,3], 0) do x, s  # silently wrong, both args!
           y = log.(x)  # this needs x's value
           fill!(x, s)  # pullback is not called
           sum(y .+ x)
       end
([2.0, 1.5, 1.3333333333333333], nothing)
julia> ForwardDiff.gradient([1,2,3]) do x
           y = log.(x)
           fill!(x, 0)
           sum(y .+ x)
       end
3-element Vector{Float64}:
1.0
0.5
0.3333333333333333
julia> Diffractor.gradient([1,2,3], 0) do x, s
           y = log.(x)  # this needs x's value
           fill!(x, s)  # poisons x, and s
           sum(y .+ x)
       end
pullback for fill! got ZeroTangent()
(NotImplemented(Main, #= REPL[22]:6 =#, mutated), NotImplemented(Main, #= REPL[22]:6 =#, no input))
# New example
julia> Zygote.gradient([1 2; 3 4], [5,6], 7) do x, y, z
           xy = x * y
           y2 = fill!(y, z)
           sum(xy .+ y2)
       end
pullback for fill! got [1.0, 1.0]
([7.0 7.0; 7.0 7.0], [4.0, 6.0], 2.0)
julia> Diffractor.gradient([1 2; 3 4], [5,6], 7) do x, y, z  # silently wrong about x
           xy = x * y  # x's gradient needs y's value, etc.
           y2 = fill!(y, z)  # poisons y, but not x
           sum(xy .+ y2)
       end
pullback for fill! got [1.0, 1.0]
([7.0 7.0; 7.0 7.0], NotImplemented(Main, #= REPL[5]:4 =#, nope), 2.0)
julia> ForwardDiff.gradient([1 2; 3 4]) do x
           y, z = [5,6], 7
           xy = x * y
           y2 = fill!(y, z)
           sum(xy .+ y2)
       end
2×2 Matrix{Int64}:
5 6
5 6
If you overload _pullback, then this is always called, even with no return. This appears to give safe answers on the above examples, often NaN:
function Zygote._pullback(__context__::Zygote.AContext, ::typeof(fill!), x::Array, v)
    old = copy(x)  # could instead just have fill!(x, NaN) on the reverse?
    y = fill!(x, v)
    back(::Nothing) = begin
        copyto!(x, old)  # restore
        (nothing, Zygote.Fill(NaN, size(x)), NaN)  # since we didn't see the return, poison it
    end
    back(dy) = begin
        copyto!(x, old)
        (nothing, Zygote.Fill(NaN, size(x)), sum(dy))  # here we know dv
    end
    return (y, back)
end
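To see why restoring the old value on the reverse pass helps, here is a hand-rolled sketch in plain Julia, assuming the reverse pass runs the pullbacks in the opposite order to the forward pass. `mul_with_pullback` and `fill_with_restore!` are hypothetical helpers for this illustration, not Zygote API, and the sketch leaves out the NaN poisoning of the mutated array's own gradient that the definition above performs.

```julia
# Hypothetical rule for x * y whose pullback reads x and y by reference.
function mul_with_pullback(x, y)
    pullback(dxy) = (dxy * y', x' * dxy)   # (x̄, ȳ)
    return x * y, pullback
end

# Hypothetical rule for fill! that saves the old contents and puts them
# back when its pullback runs, mirroring the copy / copyto! idea.
function fill_with_restore!(y, v)
    old = copy(y)
    fill!(y, v)
    function pullback(dy)
        copyto!(y, old)   # undo the mutation before earlier pullbacks run
        return sum(dy)    # gradient with respect to v
    end
    return y, pullback
end

x = [1 2; 3 4]
y = [5, 6]

# Forward pass for: xy = x * y; y2 = fill!(y, 7); sum(xy .+ y2)
xy, mul_back = mul_with_pullback(x, y)
y2, fill_back = fill_with_restore!(y, 7)

# Reverse pass, last operation first.
dxy = ones(2)
dy2 = ones(2)
dz = fill_back(dy2)       # also restores y to [5, 6]
dx, _ = mul_back(dxy)     # now sees the value y had during x * y
```

With the restore in place, `dx` has rows `5 6`, matching the ForwardDiff answer for `x`, and `y` is back to `[5, 6]` once the reverse pass finishes.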
Similar for setindex!:
function Zygote._pullback(__context__::Zygote.AContext, ::typeof(Base.setindex!), x::Array, v, ind::Integer...)
    old = x[ind...]
    y = setindex!(x, v, ind...)
    nots = map(_ -> nothing, ind)
    back(::Nothing) = begin
        x[ind...] = old
        (nothing, Zygote.Fill(NaN, size(x)), NaN, nots...)
    end
    back(dy) = begin
        x[ind...] = old
        (nothing, Zygote.Fill(NaN, size(x)), dy, nots...)  # setindex! returns the value
    end
    return (y, back)
end
Zygote.gradient([1,2,3.0], 4) do x, y
    x[1] = y^2
    sum(x .* y)
end  # should be ([0,4,4], 53), in fact all NaN
Since the above, Zygote seems to have been taught to silently ignore NotImplemented, giving more ways to get wrong answers.
Damn it Zygote. Do we have an issue open downstream to "Please don't do this"?
We had https://github.com/FluxML/Zygote.jl/issues/1227 but it was closed; I've just re-opened it. The problem was that https://github.com/FluxML/Zygote.jl/issues/1204 happened, which led to https://github.com/FluxML/Zygote.jl/pull/1205. As I mentioned in the issue, there doesn't seem to be a more incremental fix here than doing all of https://github.com/FluxML/Zygote.jl/issues/603. Am I missing a better solution?
Thinking about it a bit more, could we get away with just switching over Zygote's zero types (i.e. nothing)? The biggest obstacle I can think of is getting rid of internal pullback calls in higher-order rules to avoid premature conversion.
@mzgubic (with my help) tried to switch over Zygote's types a few years ago. It got hairy fast. Far more complex than you might think. Though now that Zygote has fewer rules that might need changing, it might be easier.
In the spirit of baby steps, I've filed https://github.com/FluxML/Zygote.jl/pull/1385 to provide a better base for future attempts at this.