
forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs

Results 149 ChainRules.jl issues

[This](https://github.com/JuliaStats/Distributions.jl/blob/c9d6c28f415025bf489ac3bec2f8eec46b0eefbd/src/genericrand.jl#L48) fallback method for `rand` in `Distributions.jl` hits [this](https://github.com/JuliaDiff/ChainRules.jl/blob/f13e0a45d10bb13f48d6208e9c9d5b4a52b96732/src/rulesets/Random/random.jl#L25) rule, which is declared non-differentiable. This results in a silent failure, where there ought to be an error if the given...

```julia
A = [1.0 1.0; 1.0 1.0]
det(A) * inv(A)
# ERROR: SingularException(2)
```
to be fair, the LU approach is remarkably resilient...
```julia
A = [1.0 1.0; 1.0 1.0-1e-16]
```
...
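One way to see why `det(A) * inv(A)` need not blow up for a singular matrix: algebraically that product is the adjugate, which involves no division and stays finite even when `det(A) == 0`. A minimal sketch in pure Python for the 2×2 case (illustrative only; the helper names `det2`, `inv2`, and `adjugate2` are hypothetical, not part of any library):

```python
# For a 2x2 matrix A = [[a, b], [c, d]]:
#   det(A)        = a*d - b*c
#   inv(A)        = (1/det(A)) * [[d, -b], [-c, a]]
#   det(A)*inv(A) = [[d, -b], [-c, a]]   <- the adjugate: no division at all

def det2(A):
    (a, b), (c, d) = A
    return a * d - b * c

def inv2(A):
    (a, b), (c, d) = A
    det = a * d - b * c
    if det == 0:
        # analogue of Julia's SingularException
        raise ZeroDivisionError("singular matrix")
    return [[d / det, -b / det], [-c / det, a / det]]

def adjugate2(A):
    (a, b), (c, d) = A
    return [[d, -b], [-c, a]]

A = [[1.0, 1.0], [1.0, 1.0]]  # exactly singular: det2(A) == 0
# det2(A) * inv2(A) would fail at inv2, but the adjugate is finite:
print(adjugate2(A))  # [[1.0, -1.0], [-1.0, 1.0]]
```

This is the same reason an LU-based formulation can remain well-behaved near singularity: the division by the determinant cancels symbolically.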

This PR fixes #576 by treating zero (co)tangents in `sqrt` as strong zeros. It partially fixes https://github.com/FluxML/Zygote.jl/issues/1101 also, but to fix it entirely, we would need to do the same...

needs version bump

This only happens when the (co)tangent is 0.
```julia
julia> using ChainRules

julia> ChainRules.frule((ChainRules.ZeroTangent(), 0.0), sqrt, 0.0)
(0.0, NaN)

julia> ChainRules.rrule(sqrt, 0.0)[2](0.0)
(ChainRulesCore.NoTangent(), NaN)
```
I suggest we adopt the...
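The mechanism behind the NaN: the derivative of `sqrt` is `1 / (2*sqrt(x))`, which is infinite at `x = 0`, and multiplying a zero (co)tangent by infinity gives `0 * Inf = NaN` under IEEE arithmetic. Treating the zero as a "strong zero" means short-circuiting before that multiply. A hedged Python sketch (function names are hypothetical, not the actual ChainRules implementation):

```python
import math

def sqrt_frule(dx, x):
    """Naive forward rule: propagate dx through d(sqrt)/dx = 1/(2*sqrt(x))."""
    y = math.sqrt(x)
    deriv = 0.5 / y if y != 0.0 else math.inf
    return y, dx * deriv          # at x == 0: 0.0 * inf == nan

def sqrt_frule_strong_zero(dx, x):
    """Treat a zero tangent as a strong zero: never multiply it by the derivative."""
    y = math.sqrt(x)
    if dx == 0.0:
        return y, 0.0             # short-circuit: a zero tangent stays zero
    return y, dx * 0.5 / y

print(sqrt_frule(0.0, 0.0))              # (0.0, nan)
print(sqrt_frule_strong_zero(0.0, 0.0))  # (0.0, 0.0)
```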

bug

As noted in #504, there are a number of cases where types of rules were constrained to `CommutativeMulNumber` where commutation of multiplication did not need to be assumed. Likewise, there...

needs version bump

the `reshape` was a primitive version of `ProjectTo`, I think?

needs version bump

Fixes https://github.com/FluxML/Zygote.jl/issues/1037

From dfdx/Yota.jl#93:
```julia
A = rand(100, 100)
x = rand(100)
rrule(*, x', A, x)  # ==> nothing
```
It's possible to binarize the operation on the AD engine side, but...
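What "binarizing" means here: rewriting the ternary product `x' * A * x` as the nested binary products `(x' * A) * x`, so that each step can use an existing two-argument rule and the pullbacks compose by the chain rule. A small illustrative sketch in pure Python (helper names are hypothetical; the gradient uses the standard identity d(x'Ax)/dx = (A + A')x):

```python
def vecmat(xt, A):
    """Row-vector times matrix: the first binary multiply, x' * A."""
    rows, cols = len(A), len(A[0])
    return [sum(xt[i] * A[i][j] for i in range(rows)) for j in range(cols)]

def dot(u, v):
    """Row-vector times column-vector: the second binary multiply."""
    return sum(ui * vi for ui, vi in zip(u, v))

def quadratic_form(x, A):
    # Two binary products instead of one ternary product:
    return dot(vecmat(x, A), x)

def quadratic_form_grad(x, A):
    """Composed pullback wrt x: (A + A') x."""
    n = len(x)
    return [sum((A[i][j] + A[j][i]) * x[j] for j in range(n)) for i in range(n)]

A = [[1.0, 2.0], [3.0, 4.0]]
x = [1.0, 1.0]
print(quadratic_form(x, A))       # 10.0
print(quadratic_form_grad(x, A))  # [7.0, 13.0]
```

The trade-off hinted at in the issue: binarizing loses the chance to fuse the ternary form into a single optimized rule (e.g. recognizing it as a quadratic form).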

missing rule

Requested here: https://discourse.julialang.org/t/implementation-of-spectral-normalization-for-machine-learning/76074 The workaround is to call `svd(X).S`, which is slower in the forward pass. But it looks like the gradient calculation with something like `svd_rev((; U=NoTangent(), s=s, V=NoTangent(), Vt=NoTangent()), NoTangent(), S̄,...

missing rule

I suspect some of the scalar rules should be using `oneunit` instead of `one`. For example, `sign`.

bug