ChainRules.jl

Remove `@non_differentiable` for `trunc`?

Open • eigenvivek opened this issue 3 years ago • 1 comment

I saw that the `@non_differentiable` tag was removed for `floor`, `ceil`, and `round` recently (https://github.com/JuliaDiff/ChainRules.jl/commit/1770bb29ca42d4e07643284e0a4917ad6ea35b57). Should the same thing be done for `trunc`?

https://github.com/JuliaDiff/ChainRules.jl/blob/7faaf5d540e1249e6b2087c1d52cbef4b85c58e8/src/rulesets/Base/nondiff.jl#L433

I am trying to take the gradient of a function that returns an index, so maybe I'm thinking about this the wrong way anyways.
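For concreteness, here is a minimal sketch of what the `@non_differentiable` tag does to an index-returning function. It assumes the ChainRulesCore package is installed; `my_index` is a hypothetical stand-in for my function, not anything in ChainRules.jl.

```julia
using ChainRulesCore  # assumption: ChainRulesCore is available

# Hypothetical index-returning function, made up for illustration.
my_index(x::Real) = trunc(Int, x) + 1

# Declare it non-differentiable for this signature.
@non_differentiable my_index(x::Real)

# rrule returns the primal value and a pullback.
y, pullback = rrule(my_index, 2.5)

# The pullback reports "no tangent" for the function itself and for x.
tangents = pullback(1.0)
```

With the tag in place, an AD system built on ChainRules rules would see no gradient flowing back through `my_index`.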

eigenvivek, Jul 19 '22 15:07

I think that commit removed them as duplicates of these:

https://github.com/JuliaDiff/ChainRules.jl/blob/cc8b9ea103abb20f9bd7016c561dd77080ac49d0/src/rulesets/Base/base.jl#L178

(Those use a different representation of zero, which is another story. And I'm not sure that `x -> round(x, digits=8)` should differ so much from `x -> convert(Float32, x)`...)
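To illustrate the comparison numerically, here is a rough sketch in plain Julia. It says nothing about what any AD package returns for either function; the helper `central_diff`, the step size, and the sample point are all arbitrary choices made for illustration.

```julia
# Central finite difference; the step `h` is an illustrative choice,
# deliberately much larger than either function's rounding granularity.
central_diff(f, x; h=1e-3) = (f(x + h) - f(x - h)) / (2h)

x = 1.2345678912

slope_round   = central_diff(y -> round(y, digits=8), x)
slope_float32 = central_diff(y -> convert(Float32, y), x)

# Both slopes come out close to 1: at this scale each map looks like the identity.
println((slope_round, slope_float32))
```

At this step size both maps are indistinguishable from the identity, which is why treating one as having zero derivative and the other as having derivative one looks inconsistent.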

Mostly I would say things like indices aren't differentiable, as you can't consider changing them by a tiny amount. Can you say more about what you have in mind?
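A small sketch of that point, using only Base Julia. The helper `central_diff` and the function `bin_index` are hypothetical names made up here, and the step size is arbitrary.

```julia
# Central finite difference; `h` is an arbitrary illustrative step size.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

# Hypothetical index-returning function: which unit-width bin does x fall in?
bin_index(x) = trunc(Int, x) + 1

flat_trunc = central_diff(trunc, 2.5)      # trunc is flat away from integers
flat_index = central_diff(bin_index, 2.5)  # a tiny nudge leaves the index unchanged
at_jump    = central_diff(trunc, 3.0)      # straddles the jump at x = 3

println((flat_trunc, flat_index, at_jump))
```

Away from the jumps the estimate is exactly zero, and at a jump it blows up as `h` shrinks, so there is no useful gradient to pass back through the index either way.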

mcabbott, Jul 19 '22 16:07