Michael Abbott
https://github.com/JuliaDiff/ChainRules.jl/pull/655 implements the same copy-to-CPU idea in [98e87c3](https://github.com/JuliaDiff/ChainRules.jl/pull/655/commits/98e87c3e7a560ce094969be6df49e569d37fc61a). Comments (or benchmarking) would be welcome. Adding test cases here would also be a great idea, as they will be run on...
`foldl` not tracking the `init` keyword is https://github.com/JuliaDiff/ChainRules.jl/issues/567; you could try with https://github.com/JuliaDiff/ChainRules.jl/pull/569 . `sum` not supporting `init` is also bad; could you make an issue on ChainRules.jl? ```julia julia> ChainRules.rrule(sum, [1,2,3];...
Yes. I guess the thing you overload should eventually be something like https://github.com/JuliaDiff/ChainRulesCore.jl/pull/528 . However, at the moment I believe you get errors from `unbroadcast` not having appropriate methods, if you...
Never got back to https://github.com/JuliaGPU/GPUArrays.jl/issues/362 . But the rule could now use `@allowscalar` via GPUArraysCore instead.
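A minimal sketch of what using `@allowscalar` from GPUArraysCore looks like. This is not the rule in question; `first_element` is a hypothetical helper, and a plain `Array` stands in for a GPU array so that it runs without a GPU:

```julia
using GPUArraysCore: @allowscalar

# Hypothetical helper: read one element, explicitly permitting scalar
# indexing for this expression only. On a GPU array this avoids the
# scalar-indexing error; on a plain Array it is simply a normal read.
first_element(x::AbstractArray) = @allowscalar x[begin]

first_element([1.0, 2.0, 3.0])
```

The macro scopes the permission to the wrapped expression, so scalar indexing elsewhere in the same code is still flagged.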
Fixed in ChainRules, but I think Zygote still uses its own older versions: ``` julia> Diffractor.gradient(randn(3,3)) do x sum(sin.(cat(x; dims=4))) end[1] 3×3 Matrix{Float64}: 0.954619 0.903961 0.297481 0.174126 0.99136 0.972581 0.53144...
FWIW, this is because `gradient` calls ProjectTo on the final answer. The rule itself is unchanged, and thus intermediate results may show this. ```julia julia> Zygote.pullback(randn(3,3)) do x sum(sin.(cat(x; dims=4)))...
IIRC the hurdle to simply deleting all of these is https://github.com/JuliaGPU/GPUArrays.jl/issues/362 . `vcat` of a mix of numbers and CuArrays mostly works, and its gradient should not use scalar indexing....
Globals are evil. I think zero is correct, as the model is linear in `x`. If you change it to use `tanh` then these all agree: ``` julia> model Dense(3...
This is only about the weird `adjoint` overload, `gradient` is fine:
```julia
julia> gradient(one, 0.0)
julia> one'(0.0)
ERROR: MethodError: no method matching getindex(::Nothing, ::Int64)
```
You can fix the error with `ChainRulesCore.@non_differentiable searchsortedlast(x, y)`, which ideally should be added here: https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Base/nondiff.jl#L380 . The immediate error is in fact `[5] (::typeof(∂(>>>)))(Δ::Int64)` and bit-shift functions like `>>>`...
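
A minimal sketch of how the `@non_differentiable` declaration behaves. To avoid adding methods for a Base function in an example, it is applied to a hypothetical wrapper, `bucket_index`, rather than to `searchsortedlast` itself:

```julia
using ChainRulesCore

# Hypothetical wrapper (not part of Base or ChainRules), so this sketch
# does not define rules for a function it does not own.
bucket_index(edges, v) = searchsortedlast(edges, v)

# Declare that no derivative flows through this integer-valued lookup.
ChainRulesCore.@non_differentiable bucket_index(edges, v)

# The macro defines an rrule whose pullback returns NoTangent() for the
# function itself and for every argument.
y, back = ChainRulesCore.rrule(bucket_index, [1.0, 2.0, 3.0], 2.5)
back(1)
```

With such a rule in place, an AD system that consults ChainRules should stop at this call instead of differentiating through its internals.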