ChainRules.jl
Remove `@non_differentiable` for `trunc`?
I saw that the `@non_differentiable` declaration was recently removed for `floor`, `ceil`, and `round` (https://github.com/JuliaDiff/ChainRules.jl/commit/1770bb29ca42d4e07643284e0a4917ad6ea35b57). Should the same be done for `trunc`?
https://github.com/JuliaDiff/ChainRules.jl/blob/7faaf5d540e1249e6b2087c1d52cbef4b85c58e8/src/rulesets/Base/nondiff.jl#L433
I am trying to take the gradient of a function that returns an index, so maybe I'm thinking about this the wrong way anyway.
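As a quick sanity check that `trunc` behaves like the other three, here is a minimal sketch in plain Base Julia that estimates each function's slope by central finite differences. No ChainRules or AD is involved, and the helper `fd` is a hypothetical name introduced only for this illustration. All four functions are piecewise constant, so away from integers the estimate is zero, and at an integer it blows up because the function jumps.

```julia
# Central finite-difference estimate of df/dx at x (hypothetical helper, Base Julia only).
fd(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# Away from integers, all four rounding functions are locally constant,
# so each estimate here is 0.0.
for f in (floor, ceil, round, trunc)
    println(f, ": ", fd(f, 2.3))
end

# At an integer, trunc jumps from 1.0 to 2.0, so the estimate is roughly 1/(2h), not a slope.
println("trunc at 2.0: ", fd(trunc, 2.0))
```

This only illustrates that `trunc` has the same "zero almost everywhere" derivative as `floor`, `ceil`, and `round`; it says nothing about how ChainRules chooses to encode that zero.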
I think that commit removed them because they duplicated these rules:
https://github.com/JuliaDiff/ChainRules.jl/blob/cc8b9ea103abb20f9bd7016c561dd77080ac49d0/src/rulesets/Base/base.jl#L178
(Those rules use a different representation of zero, which is another story. I'm also not sure that `x -> round(x, digits=8)` should be treated so differently from `x -> convert(Float32, x)`...)
Mostly I would say that things like indices aren't differentiable, since you can't change them by a tiny amount. Can you say more about what you have in mind?