TensorComprehensions
Indirections on LHS [for indexing grad]
The reason for supporting LHS indirection is the following. Say someone has the lookup table:
```
def lut(float(B, R) M, int32(B, N) I) -> (O) {
    O(b, n) +=! M(I(b, n), r)
}
```
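As a point of reference, here is a minimal pure-Python sketch of what the forward `lut` TC computes, assuming the usual TC reading of `+=!` (a sum reduction whose output is first initialized to zero). The function name and the list-of-lists representation are illustrative choices, not TC's generated code.

```python
# Hypothetical reference semantics for the `lut` TC (plain Python, not TC output):
#   O(b, n) +=! M(I(b, n), r)
# `+=!` initializes O to zero, then sums over the reduction index r,
# whose range [0, R) comes from the second dimension of M.

def lut(M, I):
    """M: B x R list of rows; I: B x N row indices into M. Returns O: B x N."""
    R = len(M[0])
    O = [[0.0] * len(I[b]) for b in range(len(I))]  # the "!" in +=!: start from 0
    for b in range(len(I)):
        for n in range(len(I[b])):
            row = M[I[b][n]]  # RHS indirection: a gather, the read address comes from I
            for r in range(R):
                O[b][n] += row[r]  # the "+=" in +=!: reduce over r
    return O

M = [[1.0, 2.0], [3.0, 4.0]]
I = [[1, 0, 1], [0, 0, 1]]
print(lut(M, I))  # [[7.0, 3.0, 7.0], [3.0, 3.0, 7.0]]
```

Indirection on the right-hand side only changes where values are read from; every element of O is still written exactly once per reduction step.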
Now they want to define the gradient for this. The most intuitive way to write the TC for the backward pass would be:
```
M(I(b, n), r) = O_grad(b, r)
```
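For comparison, the gradient such a backward TC has to produce is a scatter-add. Since O(b, n) sums M(I(b, n), r) over r, every element of row I(b, n) of M receives the gradient of O at that same (b, n). The hypothetical `lut_grad_M` below is a plain-Python sketch of that computation, not TC code; it takes the row count and R as explicit arguments because neither can be read off I and O_grad.

```python
# Hypothetical reference semantics for dL/dM of `lut` (plain Python, not TC output).
# Forward: O[b][n] = sum_r M[I[b][n]][r], so dO[b][n]/dM[i][r] = 1 iff I[b][n] == i.

def lut_grad_M(O_grad, I, num_rows, R):
    """Scatter-add O_grad (B x N) into a num_rows x R gradient for M.

    num_rows and R cannot be derived from O_grad and I, so they are passed in.
    """
    M_grad = [[0.0] * R for _ in range(num_rows)]  # rows never selected by I stay 0
    for b in range(len(I)):
        for n in range(len(I[b])):
            i = I[b][n]  # LHS indirection: the write address comes from data
            for r in range(R):
                # '+=', not '=': different (b, n) may select the same row i
                M_grad[i][r] += O_grad[b][n]
    return M_grad

I = [[1, 0, 1], [0, 0, 1]]
O_grad = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(lut_grad_M(O_grad, I, num_rows=2, R=2))  # [[11.0, 11.0], [10.0, 10.0]]
```

In the sketch, replacing `+=` with `=` would keep only the last write to each row whenever I contains repeated entries, and rows that I never selects keep their initial value.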
From @ftynse:
- [ ] without other uses of M, we cannot infer its size in the first dimension, even with a `where` clause; this is not a problem for input tensors whose size is specified
- [x] Halide->Polyhedral lowering pass should differentiate between must writes (all writes are must-writes now) and may writes
- [x] so should polyhedral dependence analysis
- [ ] advanced transformations (reductions, shared memory) should be disabled
- [ ] codegen should properly emit LHS indirections; hopefully this just works thanks to Halide
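The must/may distinction in the checklist can be seen on concrete index data. The helpers below are hypothetical and written in plain Python rather than in terms of TC's polyhedral representation: two index tensors with the same declared shape can write every row of M or only one of them, and the number of rows M needs is only known once the values of I are.

```python
# Hypothetical helpers (plain Python) showing why M(I(b, n), r) = ... is a may-write:
# which rows of M it touches, and how many rows M must have, depend on I's values.

def written_rows(I):
    """Sorted row indices of M written by M(I(b, n), r) = ... over all (b, n)."""
    return sorted({i for row in I for i in row})

def overwrites_all_rows(I, num_rows):
    """True only if every row of M is written, so the write acts as a must-write."""
    return written_rows(I) == list(range(num_rows))

def min_num_rows(I):
    """Smallest first dimension of M consistent with I (only known at run time)."""
    return max(i for row in I for i in row) + 1

# Same declared shape (2 x 2), different contents, different write footprints.
I_a = [[0, 1], [2, 3]]
I_b = [[1, 1], [1, 1]]
print(written_rows(I_a), overwrites_all_rows(I_a, 4), min_num_rows(I_a))  # [0, 1, 2, 3] True 4
print(written_rows(I_b), overwrites_all_rows(I_b, 4), min_num_rows(I_b))  # [1] False 2
```

A compiler sees only the declared shape of I, so it must treat each row of M as possibly written, not definitely written. With I_b, rows 0, 2 and 3 keep whatever values they held before, so a dependence analysis that treated the statement as a must-write of all of M would wrongly consider earlier writes to those rows dead.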
cc @martinraison who mentioned this case
cc @ftynse for thoughts on this.
@prigoyal We need to extend the gradient to a full TC definition with a header to better see which features are not supported. Off the top of my head:
- [ ] without other uses of M, we cannot infer its size in the first dimension, even with a `where` clause; this is not a problem for input tensors whose size is specified
- [ ] Halide->Polyhedral lowering pass should differentiate between must writes (all writes are must-writes now) and may writes
- [ ] so should the polyhedral dependence analysis
- [ ] advanced transformations (reductions, shared memory) should be disabled
- [ ] codegen should properly emit LHS indirections; hopefully this just works thanks to Halide
Thanks @ftynse, moving your comment into the first comment as a checklist.
Certainly this is just basic flow support. Making this run fast is a different story...
Duplicate of #9?
I started working on this a while ago, but got stuck for the reason described in there.
#9 seems to be a broader issue... On the other hand, indirection requires distinguishing may-writes from must-writes, which is not necessarily the case for other expressions.
The general scattering problem is really no worse than simple indirection on the LHS. Of all the computed values you could put there, a load is the least well-behaved. It's totally unbounded, unordered, etc.
What I meant was that scattering may be better in some cases :)
Ah, indeed. Scattering to affine or constant addresses is way easier.
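The difference between scattering through a load and scattering to affine or constant addresses shows up even in one dimension. In the hypothetical sketch below, for an affine address a*i + c the bounds of the written range and whether two iterations can collide follow from a, c and n alone, while for a loaded address the same facts require every index value. A polyhedral compiler derives such facts symbolically from integer sets rather than by evaluating functions like these; the point is only that the affine case needs no data.

```python
# Hypothetical comparison (plain Python): what is known about a 1-D scatter
# out[addr(i)] = ... for i in [0, n), n >= 1, with an affine vs. a loaded address.

def affine_scatter_facts(a, c, n):
    """addr(i) = a*i + c: bounds and collision-freedom follow from a, c, n alone."""
    first, last = c, a * (n - 1) + c
    lo, hi = min(first, last), max(first, last)
    collision_free = a != 0 or n == 1  # a == 0 is a constant address: all i collide
    return lo, hi, collision_free

def loaded_scatter_facts(idx):
    """addr(i) = idx[i]: the same facts need every value of idx, i.e. run-time data."""
    lo, hi = min(idx), max(idx)
    collision_free = len(set(idx)) == len(idx)
    return lo, hi, collision_free

print(affine_scatter_facts(2, 1, 4))    # (1, 7, True)
print(affine_scatter_facts(0, 5, 3))    # (5, 5, False)
print(loaded_scatter_facts([3, 0, 3]))  # (0, 3, False)
```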