ChainRules.jl
forward and reverse mode automatic differentiation primitives for Julia Base + StdLibs
Because closures are not consistent types, they inhibit precompilation. If the closures used for the adjoints were instead callable types, which would make sure that every session has the same...
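As a rough illustration of the idea in the excerpt above, the sketch below contrasts a pullback written as a closure with one written as a callable struct. The names `scale_with_closure`, `scale_with_struct`, and `ScalePullback` are hypothetical and are not taken from ChainRules.jl. The only point is that the struct gives the pullback a named type, while the closure gets a compiler-generated one.

```julia
# Closure version: the pullback's type name is generated by the compiler.
function scale_with_closure(a, x)
    pullback = dy -> (a * dy,)
    return a * x, pullback
end

# Callable-struct version (hypothetical): the pullback has a named type.
struct ScalePullback{T}
    a::T
end

# Make instances callable, mirroring the closure above.
(pb::ScalePullback)(dy) = (pb.a * dy,)

function scale_with_struct(a, x)
    return a * x, ScalePullback(a)
end

y1, pb1 = scale_with_closure(2.0, 3.0)
y2, pb2 = scale_with_struct(2.0, 3.0)
```

Both pullbacks compute the same value; they differ only in how their type is named, which is the property the excerpt ties to precompilation.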
The issue below contains an MWE of Zygote erroring on the `init` kwarg to `Base.sum`. The best way to fix it would be to define an `rrule` for `sum`: https://github.com/FluxML/Zygote.jl/issues/1279
Resolves #603
Some rules use scalar indexing, which breaks GPU compatibility (e.g. https://github.com/JuliaDiff/ChainRules.jl/blob/3b3791f10bc88c41f004fbb9eb229745d1764593/src/rulesets/LinearAlgebra/norm.jl#L187). One solution would be to use `@allowscalar` from GPUArrays, but one concern about adding that dependency is the loading time...
```julia
julia> func(x) = sum(repeat(x, inner = (1, 3)))
func (generic function with 1 method)

julia> func(CUDA.rand(2,3))
7.5995846f0

julia> Zygote.gradient(func,CUDA.rand(2,3))
ERROR: Scalar indexing is disallowed. Invocation of getindex resulted in...
```
```julia
julia> rrule(repeat, falses(1), 1)
ERROR: MethodError: rrule(::typeof(repeat), ::BitVector, ::Int64) is ambiguous. Candidates:
  rrule(::typeof(repeat), xs::AbstractArray, counts::Integer...) in ChainRules at /home/brianc/.julia/packages/ChainRules/o1vND/src/rulesets/Base/array.jl:191
  rrule(::typeof(repeat), var"543"::AbstractArray{Bool}, var"544"...; repeat_pullback) in ChainRules at /home/brianc/.julia/packages/ChainRules/o1vND/src/rulesets/Base/nondiff.jl:65
Possible fix,...
```
Closes #567, perhaps in the minimal way, by attaching these rules to internal functions which take positional arguments. The gradient for `init` is just `@not_implemented` for now. One nice effect is...
Zygote does not seem smart enough to handle this on its own. Pending questions:
- [ ] How to get tests working. I assume https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/195 would be required.
- [...
I saw that the `@non_differentiable` tag was removed for floor, ceil, and round recently (https://github.com/JuliaDiff/ChainRules.jl/commit/1770bb29ca42d4e07643284e0a4917ad6ea35b57). Should the same thing be done for `trunc`? https://github.com/JuliaDiff/ChainRules.jl/blob/7faaf5d540e1249e6b2087c1d52cbef4b85c58e8/src/rulesets/Base/nondiff.jl#L433 I am trying to take the...
The `@test_skip` confused me: I thought the test had failed, but it is just that CRTU is not good at this https://github.com/JuliaDiff/ChainRulesTestUtils.jl/issues/219 (NFC == no functional change)