
Remove bad non-differentiable

Open willtebbutt opened this issue 3 years ago • 5 comments

Resolves #603

willtebbutt avatar Mar 28 '22 20:03 willtebbutt

Looks like this currently causes some downstream issues -- will attempt to isolate the particular rules that we need to keep Zygote and Diffractor tests passing.

willtebbutt avatar Mar 29 '22 08:03 willtebbutt

The Diffractor downstream tests never pass anyway.

Zygote failures are things like this:

/home/runner/work/ChainRules.jl/ChainRules.jl/downstream/test/gradcheck.jl:1730
  Test threw exception
  Expression: gradient((x->begin
                sum(rand(Random.GLOBAL_RNG, Float32, 1, 1))
            end), 1) == (nothing,)

for which the immediate fix is to add rules something like

@non_differentiable rand(::AbstractRNG, ::Type{<:Number}, ::Tuple)
@non_differentiable rand(::AbstractRNG, ::Type{<:Number}, ::Integer...)
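
With such rules in place, the failing expectation above can be checked directly (a quick sketch, assuming Zygote and the Random stdlib are loaded):

using Zygote, Random

# The draw is treated as a constant, so the gradient with respect to
# the (unused) argument is `nothing`:
gradient(x -> sum(rand(Random.GLOBAL_RNG, Float32, 1, 1)), 1)  # (nothing,)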

mcabbott avatar Mar 29 '22 17:03 mcabbott

This seems surprising, since the PR does not touch rand; I would assume that removing rules for rand! should not affect something like sum(rand(args...))?

Edit: Ah, probably rand(Float32, 1, 1) calls rand! internally...
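
A quick way to confirm that dispatch from the REPL, using only the Random stdlib:

using Random

# Shows that this call hits the rand(r::AbstractRNG, ::Type{X}, dims::Dims)
# method quoted below, whose body forwards to rand!:
@which rand(Random.GLOBAL_RNG, Float32, (1, 1))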

devmotion avatar Mar 29 '22 17:03 devmotion

Yes, the method is:

rand(r::AbstractRNG, ::Type{X}, dims::Dims) where {X} = rand!(r, Array{X}(undef, dims), X)

If we allow rand!(rng, array), then presumably we should allow rand!(rng, array, eltype) too; a sketch of what those rules would look like is below.

I'd rather not have any of these mutating functions, but removing them might break too many things.
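
For concreteness, keeping the mutating methods would mean rules roughly like these (a sketch only, mirroring the rand! signatures in the Random stdlib; as said above, I'd rather drop them):

# Hypothetical rules, shown only to illustrate what allowing the
# mutating rand! methods would entail:
@non_differentiable rand!(::AbstractRNG, ::AbstractArray)
@non_differentiable rand!(::AbstractRNG, ::AbstractArray, ::Any)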

mcabbott avatar Mar 29 '22 17:03 mcabbott

Stumbled on this after a long dive (w/ help from Zack Li and Brian Chen on Slack) trying to understand this incorrectly dropped gradient:

using Distributions, LinearAlgebra, Random, Zygote

gradient(σ -> sum(rand(Xoshiro(1), MvNormal(zeros(2), σ*I))), 1) # nothing

Conversely, with this PR, instead of dropping it you (correctly) get an error, since the rand code in Distributions.jl is mutating. If you load DistributionsAD, then it works. But yeah, it would be good not to silently and incorrectly drop it, so +1 from me on this PR (at least w.r.t. this; no comment on anything else breaking).
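
For what it's worth, a finite-difference check makes the wrongness concrete (a sketch assuming FiniteDifferences.jl is available; with a fixed seed the sampler is a smooth, deterministic function of σ, so the true derivative is nonzero):

using Distributions, FiniteDifferences, LinearAlgebra, Random

f(σ) = sum(rand(Xoshiro(1), MvNormal(zeros(2), σ*I)))

# A nonzero finite-difference derivative, so `nothing` from `gradient`
# is a silently wrong answer, not a genuine lack of dependence:
central_fdm(5, 1)(f, 1.0)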

marius311 avatar Aug 01 '22 22:08 marius311