
Add `tupdate` to `Tensor` class and start simplifying `tscatter`

Open · Mikolaj opened this issue 1 year ago • 8 comments

It should be such that `tupdate (tzero sh) ix v` is the transpose of `tindex v ix`, viewed as linear maps. See also:

https://github.com/Mikolaj/horde-ad/blob/6f88617de23a4d9fb328b352cf43fcf4cffd97b8/simplified/HordeAd/Core/AstSimplify.hs#L433
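Here is a minimal sketch of the intended property. It uses a hypothetical flat, row-major tensor type `T`, which is not horde-ad's actual representation or API. The helpers `slice` and `tdot` are illustrative too. Indexes are prefixes of any length, so `tindex v ix` returns a subtensor. `tupdate (tzero sh) ix` is then the transpose of `tindex _ ix` exactly when `<tindex v ix, w> == <v, tupdate (tzero sh) ix w>` holds for all `v` and `w`:

```haskell
-- Hypothetical flat, row-major tensor; not horde-ad's representation.
data T = T { tShape :: [Int], tElems :: [Double] }
  deriving (Eq, Show)

-- Zero tensor of the given shape.
tzero :: [Int] -> T
tzero sh = T sh (replicate (product sh) 0)

-- Row-major linear offset of a (possibly multi-component) index prefix.
offsetOf :: [Int] -> [Int] -> Int
offsetOf dims ix = foldl (\acc (d, i) -> acc * d + i) 0 (zip dims ix)

-- Start position and length of the subtensor at index prefix ix (no bounds checks).
slice :: [Int] -> [Int] -> (Int, Int)
slice sh ix =
  let k = length ix
      sz = product (drop k sh)
  in (offsetOf (take k sh) ix * sz, sz)

-- tindex v ix: the subtensor of v at index prefix ix.
tindex :: T -> [Int] -> T
tindex (T sh xs) ix =
  let (off, sz) = slice sh ix
  in T (drop (length ix) sh) (take sz (drop off xs))

-- tupdate u ix v: u with the subtensor at index prefix ix replaced by v.
tupdate :: T -> [Int] -> T -> T
tupdate (T sh xs) ix (T sh' ys)
  | sh' /= drop (length ix) sh = error "tupdate: shape mismatch"
  | otherwise =
      let (off, sz) = slice sh ix
      in T sh (take off xs ++ ys ++ drop (off + sz) xs)

-- Dot product of two tensors of equal shape.
tdot :: T -> T -> Double
tdot (T _ xs) (T _ ys) = sum (zipWith (*) xs ys)

-- Transpose property: <tindex v ix, w> == <v, tupdate (tzero sh) ix w>,
-- compared with a small relative tolerance because these are Doubles.
adjointHolds :: T -> [Int] -> T -> Bool
adjointHolds v ix w =
  let lhs = tdot (tindex v ix) w
      rhs = tdot v (tupdate (tzero (tShape v)) ix w)
  in abs (lhs - rhs) <= 1e-9 * (1 + abs lhs)
```

For example, with `v = T [2, 3, 2] [1 .. 12]`, the non-singleton index `ix = [1, 2]` and `w = T [2] [10, 20]`, both sides of the equation evaluate to `350`.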

Then `tscatter` can probably be simplified using `tupdate`, similarly to how `tgather` is simplified using `tindex` right now. I'm not sure how much of the current, complex `tgather` simplification code would dualize, but at least the trivial cases should, and they offer great benefits whenever they apply.
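One way to picture the trivial cases is the following sketch, again on a hypothetical flat tensor type. The explicit `m` argument (the rank of the iteration domain) and the `simplifyGather`/`simplifyScatter` functions are illustrative assumptions, not horde-ad's `AstSimplify` code. The sketch also assumes that scatter sums colliding contributions into a zero tensor, as the transpose of gather must. Under these assumptions, an empty iteration domain turns a gather into a single `tindex` and, dually, a scatter into a single `tupdate` of a zero tensor:

```haskell
-- Hypothetical flat, row-major tensor; not horde-ad's representation.
data T = T { tShape :: [Int], tElems :: [Double] }
  deriving (Eq, Show)

tzero :: [Int] -> T
tzero sh = T sh (replicate (product sh) 0)

-- Start and length of the subtensor at index prefix ix (no bounds or shape checks).
slice :: [Int] -> [Int] -> (Int, Int)
slice sh ix =
  let k = length ix
      sz = product (drop k sh)
      off = foldl (\acc (d, i) -> acc * d + i) 0 (zip (take k sh) ix)
  in (off * sz, sz)

tindex :: T -> [Int] -> T
tindex (T sh xs) ix =
  let (off, sz) = slice sh ix
  in T (drop (length ix) sh) (take sz (drop off xs))

tupdate :: T -> [Int] -> T -> T
tupdate (T sh xs) ix (T _ ys) =
  let (off, sz) = slice sh ix
  in T sh (take off xs ++ ys ++ drop (off + sz) xs)

-- Elementwise sum; shapes are assumed equal.
tplus :: T -> T -> T
tplus (T sh xs) (T _ ys) = T sh (zipWith (+) xs ys)

-- All indices of a shape in row-major order; indicesOf [] == [[]].
indicesOf :: [Int] -> [[Int]]
indicesOf = mapM (\n -> [0 .. n - 1])

-- Gather over an m-dimensional domain: out[i] = tindex v (f i).
tgather :: Int -> [Int] -> T -> ([Int] -> [Int]) -> T
tgather m sh v f = T sh (concatMap (tElems . tindex v . f) (indicesOf (take m sh)))

-- Scatter: start from zeros and add v[i] into position f i, summing collisions.
tscatter :: Int -> [Int] -> T -> ([Int] -> [Int]) -> T
tscatter m sh v f = foldl step (tzero sh) (indicesOf (take m (tShape v)))
  where
    step acc i = tupdate acc (f i) (tindex acc (f i) `tplus` tindex v i)

-- Trivial-case rewrites: with an empty iteration domain (m == 0),
-- gather is a single tindex and, dually, scatter is a single tupdate.
simplifyGather, simplifyScatter :: Int -> [Int] -> T -> ([Int] -> [Int]) -> T
simplifyGather 0 _ v f = tindex v (f [])
simplifyGather m sh v f = tgather m sh v f
simplifyScatter 0 sh v f = tupdate (tzero sh) (f []) v
simplifyScatter m sh v f = tscatter m sh v f
```

The non-trivial cases fall through to the general definitions unchanged. For instance, `tscatter 1 [2] (T [3] [1, 2, 3]) (map (`mod` 2))` sums the colliding contributions into `T [2] [4, 2]`.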

I suppose we'd also need an Ast term for the operation, vectorization rules, and forward-pass and transpose rules. A similar operation is already implemented at the low level, because it's needed to implement scatter:

https://github.com/Mikolaj/horde-ad/blob/6f88617de23a4d9fb328b352cf43fcf4cffd97b8/src/common/HordeAd/Internal/TensorOps.hs#L101-L117

This needs to be generalized to non-singleton indexes but, OTOH, it can be specialized to just one update, at least initially.
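The forward-pass and transpose rules could be quite small, because `tupdate u ix v` is linear in the pair `(u, v)`. The tangent map is `tupdate` itself. The transpose sends a cotangent `c` to two parts: `c` with the `ix` slot zeroed, for `u`, and `tindex c ix`, for `v`. Below is a sketch on a hypothetical flat tensor type with a non-singleton index prefix. The names `tupdateJvp`, `tupdateTranspose` and `transposeHolds` are made up for illustration and are not horde-ad functions:

```haskell
-- Hypothetical flat, row-major tensor; not horde-ad's representation.
data T = T { tShape :: [Int], tElems :: [Double] }
  deriving (Eq, Show)

tzero :: [Int] -> T
tzero sh = T sh (replicate (product sh) 0)

-- Start and length of the subtensor at index prefix ix (no bounds checks).
slice :: [Int] -> [Int] -> (Int, Int)
slice sh ix =
  let k = length ix
      sz = product (drop k sh)
      off = foldl (\acc (d, i) -> acc * d + i) 0 (zip (take k sh) ix)
  in (off * sz, sz)

tindex :: T -> [Int] -> T
tindex (T sh xs) ix =
  let (off, sz) = slice sh ix
  in T (drop (length ix) sh) (take sz (drop off xs))

tupdate :: T -> [Int] -> T -> T
tupdate (T sh xs) ix (T sh' ys)
  | sh' /= drop (length ix) sh = error "tupdate: shape mismatch"
  | otherwise =
      let (off, sz) = slice sh ix
      in T sh (take off xs ++ ys ++ drop (off + sz) xs)

tdot :: T -> T -> Double
tdot (T _ xs) (T _ ys) = sum (zipWith (*) xs ys)

-- Forward pass: tupdate is linear in (u, v), so the tangent map is tupdate itself.
tupdateJvp :: [Int] -> T -> T -> T
tupdateJvp ix du dv = tupdate du ix dv

-- Transpose: split a cotangent c of the result into cotangents for u and v.
tupdateTranspose :: [Int] -> T -> (T, T)
tupdateTranspose ix c =
  let cv = tindex c ix                       -- v only reaches the slot at ix
      cu = tupdate c ix (tzero (tShape cv))  -- u reaches every other position
  in (cu, cv)

-- Defining equation of the transpose: <jvp du dv, c> == <du, cu> + <dv, cv>.
transposeHolds :: [Int] -> T -> T -> T -> Bool
transposeHolds ix du dv c =
  let (cu, cv) = tupdateTranspose ix c
      lhs = tdot (tupdateJvp ix du dv) c
      rhs = tdot du cu + tdot dv cv
  in abs (lhs - rhs) <= 1e-9 * (1 + abs lhs)
```

As an example, take shape `[2, 2, 2]`, `ix = [1, 0]`, `du = T [2, 2, 2] [1 .. 8]`, `dv = T [2] [10, 20]` and `c = du`. The left side of the equation is `313` and the right side is `143 + 170`.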

Overall, this ticket is a big chunk of work, but quite modular. A couple of its parts, though probably intertwined with others, are crucial for the performance of the simplified horde-ad.

Mikolaj, Apr 16 '23 08:04