Zygote.jl

`map` overreach?

Open · willtebbutt opened this issue 5 years ago · 3 comments

Zygote's current `map` adjoint is arguably a bit optimistic about the types of things it is able to handle.

For example, an issue in KernelFunctions.jl cropped up because we define a custom `AbstractVector` type that wraps a `Matrix` and lets it masquerade as a vector of vectors.

Under the hood, this type implements various operations efficiently in terms of the wrapped matrix. It would be reasonable to expect Zygote to exploit these efficient implementations (since it differentiates by composition), but instead it hits the `map` adjoint and literally treats the object as a vector of vectors, which hurts performance.
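
A minimal sketch of such a wrapper, in plain Julia and without Zygote. The type name `ColumnVectors` and the helper `colnorms2` are hypothetical, chosen for illustration; this is not KernelFunctions.jl's actual implementation. It shows that a generic `map` sees one freshly allocated column vector per element, while a method written against the wrapped matrix can do the same work in a single reduction.

```julia
# Hypothetical wrapper (illustration only): a Matrix viewed as a vector of its columns.
struct ColumnVectors{T} <: AbstractVector{Vector{T}}
    X::Matrix{T}
end

# Minimal AbstractVector interface: one element per column.
Base.size(x::ColumnVectors) = (size(x.X, 2),)
Base.getindex(x::ColumnVectors, i::Int) = x.X[:, i]  # copies column i

# Hypothetical efficient operation on the wrapped matrix:
# squared Euclidean norm of every column in one reduction.
colnorms2(x::ColumnVectors) = vec(sum(abs2, x.X; dims=1))

X = [1.0 2.0 3.0; 4.0 5.0 6.0]
x = ColumnVectors(X)

# Generic element-wise route: materialises one Vector per column.
slow = map(v -> sum(abs2, v), x)
# Specialised route: a single reduction over the matrix.
fast = colnorms2(x)
```

Both routes compute the same values; the difference lies in how much work and allocation they do, which is exactly what gets lost when an AD rule treats the wrapper as an ordinary vector of vectors.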

I propose imposing further type constraints on the `map` adjoint, perhaps restricting it to `StridedArray` or `DenseArray`, whichever is deemed the better target. @MikeInnes @dhairyagandhi96 any thoughts?
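
As a rough illustration of how such a constraint plays out under Julia's dispatch, the hypothetical function `rule_for` below (not Zygote's code) has one method per candidate target. A `Vector` is a `DenseArray`; a column `view` of a matrix is a `StridedArray` but not a `DenseArray`; a wrapper that subtypes `AbstractVector` directly, like the hypothetical `Wrapped`, is neither and falls through to the generic method.

```julia
# Hypothetical wrapper that subtypes AbstractVector directly,
# standing in for a custom array type like the one described above.
struct Wrapped <: AbstractVector{Float64}
    data::Vector{Float64}
end
Base.size(w::Wrapped) = size(w.data)
Base.getindex(w::Wrapped, i::Int) = w.data[i]

# One method per candidate constraint; dispatch picks the most specific.
rule_for(xs::DenseArray) = :dense
rule_for(xs::StridedArray) = :strided   # StridedArray includes DenseArray
rule_for(xs::AbstractArray) = :generic

A = randn(3, 3)
rule_for(A[:, 1])            # :dense    (a Vector{Float64})
rule_for(view(A, :, 1))      # :strided  (strided SubArray, not a DenseArray)
rule_for(Wrapped(zeros(3)))  # :generic  (custom AbstractVector)
```

Choosing `StridedArray` would admit strided views as well as plain arrays, while `DenseArray` is the narrower of the two; either way, custom wrappers would no longer be captured by the rule.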

willtebbutt, May 14 '20 10:05

Sounds fine to me. This is a tricky tradeoff, unfortunately; probably the only real answer is to delete the adjoint entirely and support differentiating through `map` directly, but that won't work right now.

MikeInnes, May 18 '20 15:05

I partly agree. There's definitely something to be exploited in knowing that `map` acts independently on each element of its input, whether for compile times or run-time performance, so writing custom rules seems like a no-brainer to me; it's just a question of the level of generality at which you implement them.
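
To make the element-wise structure concrete, here is a hand-written reverse-mode rule for `map(f, xs)` in plain Julia, assuming `f` maps scalars to scalars and its derivative `df` is supplied by hand. The name `map_with_pullback` is hypothetical and this is not Zygote's or ChainRules' implementation; the point is only that, because each output depends on exactly one input, the pullback is itself a `map`: `xbar[i] = ybar[i] * df(xs[i])`.

```julia
# Hypothetical sketch: reverse-mode rule for map over a dense array,
# assuming f is scalar -> scalar and df is its hand-written derivative.
function map_with_pullback(f, df, xs::DenseArray)
    ys = map(f, xs)
    # Each output depends only on the matching input, so the
    # cotangent is propagated element by element.
    pullback(ybar) = map((yb, x) -> yb * df(x), ybar, xs)
    return ys, pullback
end

xs = [0.0, 1.0, 2.0]
ys, back = map_with_pullback(x -> x^2, x -> 2x, xs)
xbar = back(ones(3))
```

Restricting `xs` to `DenseArray` here mirrors the level-of-generality question: the rule is only claimed for array types where an element-wise treatment is known to be appropriate.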

I'll try to remember to make a PR on this soon.

willtebbutt, May 18 '20 17:05

Perhaps we could kill two birds with one stone if https://github.com/JuliaDiff/ChainRules.jl/issues/314 gets implemented. Moving to ChainRules would also give us `ProjectTo`, which in theory could handle more array types (modulo efficiency concerns).

ToucheSir, May 09 '22 14:05