Add `tupdate` to `Tensor` class and start simplifying `tscatter`
It should be such that `tupdate (tzero sh) ix v` is the transpose of `tindex v ix`. See also
https://github.com/Mikolaj/horde-ad/blob/6f88617de23a4d9fb328b352cf43fcf4cffd97b8/simplified/HordeAd/Core/AstSimplify.hs#L433
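To make the intended transpose property concrete, here is a minimal sketch in a toy model. It is not horde-ad's API: tensors are modelled as lists of rows, an index is a single row number, and the functions named `tzero`, `tindex` and `tupdate` below are hypothetical stand-ins with simplified types. The helpers `dotV`, `dotM` and `transposeLaw` are invented for this illustration. The law is stated as the usual adjoint identity: pairing `tindex t ix` with `v` gives the same number as pairing `t` with `tupdate (tzero sh) ix v`.

```haskell
-- Toy model (an assumption, not horde-ad's types): a rank-2 tensor is a
-- list of rows and an index is a single row number.
type Vec = [Double]
type Mat = [Vec]

-- Hypothetical stand-in for tzero: an all-zero matrix of the given shape.
tzero :: (Int, Int) -> Mat
tzero (rows, cols) = replicate rows (replicate cols 0)

-- Hypothetical stand-in for tindex: project out row ix.
tindex :: Mat -> Int -> Vec
tindex t ix = t !! ix

-- Hypothetical stand-in for tupdate: overwrite row ix with v.
tupdate :: Mat -> Int -> Vec -> Mat
tupdate t ix v = [ if i == ix then v else row | (i, row) <- zip [0 ..] t ]

-- Inner products on vectors and matrices.
dotV :: Vec -> Vec -> Double
dotV xs ys = sum (zipWith (*) xs ys)

dotM :: Mat -> Mat -> Double
dotM a b = sum (zipWith dotV a b)

-- Adjoint identity: <tindex t ix, v> == <t, tupdate (tzero sh) ix v>.
-- Exact equality is only safe here for small integer-valued Doubles.
transposeLaw :: (Int, Int) -> Mat -> Int -> Vec -> Bool
transposeLaw sh t ix v =
  dotV (tindex t ix) v == dotM t (tupdate (tzero sh) ix v)
```

For example, `transposeLaw (3, 2) [[1,2],[3,4],[5,6]] 1 [10,20]` compares `3*10 + 4*20` on both sides. A property test over random inputs would be the natural way to check the real operation once it exists.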
Probably `tscatter` can then be simplified using `tupdate`, similarly to how `tgather` simplifies using `tindex` right now. I'm not sure how much of the current complex `tgather` simplification code would dualize, but at least the trivial cases should, and they offer great benefits whenever they apply.
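As an illustration of what a "trivial case" could look like, the sketch below uses the same kind of toy model (lists of rows, single row indexes). It is not horde-ad's code. It assumes that scatter adds each source row into the target row chosen by an index function, and the signatures of `tscatter` and `tupdate` here are simplified, hypothetical ones. Under that assumption, scattering a one-row source is exactly one `tupdate` on a zero tensor, dual to a one-element gather being a `tindex`.

```haskell
-- Toy model (an assumption, not horde-ad's types).
type Vec = [Double]
type Mat = [Vec]

tzero :: (Int, Int) -> Mat
tzero (rows, cols) = replicate rows (replicate cols 0)

-- Hypothetical stand-in for tupdate: overwrite row ix with v.
tupdate :: Mat -> Int -> Vec -> Mat
tupdate t ix v = [ if i == ix then v else row | (i, row) <- zip [0 ..] t ]

-- Elementwise sum of two matrices of the same shape.
addM :: Mat -> Mat -> Mat
addM = zipWith (zipWith (+))

-- Toy scatter (assumed semantics): row i of src is added into row (f i)
-- of a zero tensor of shape sh; colliding rows are summed.
tscatter :: (Int, Int) -> Mat -> (Int -> Int) -> Mat
tscatter sh src f =
  foldr addM (tzero sh)
        [ tupdate (tzero sh) (f i) row | (i, row) <- zip [0 ..] src ]

-- The trivial case: a one-row source scatters to a single update.
scatterSingleton :: (Int, Int) -> Vec -> Int -> Bool
scatterSingleton sh v ix =
  tscatter sh [v] (const ix) == tupdate (tzero sh) ix v
```

Writing the toy scatter as a sum of single updates also suggests why the general case is harder: with several source rows the updates may collide, so they cannot simply be composed as overwrites.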
I suppose we'd also need an Ast term for the operation, vectorization rules, and forward pass and transpose rules. A similar operation is already implemented at the low level, because it's needed to implement `scatter`:
https://github.com/Mikolaj/horde-ad/blob/6f88617de23a4d9fb328b352cf43fcf4cffd97b8/src/common/HordeAd/Internal/TensorOps.hs#L101-L117
This needs to be generalized to non-singleton indexes but, OTOH, it can be specialized to just one update, at least initially.
Overall, this ticket is a big chunk of work, but quite modular. A couple of its parts, though probably intertwined with others, are crucial for the performance of the simplified horde-ad.