lightning-thunder

Advanced Indexing with sequences

Open • k223kim opened this issue 1 year ago • 8 comments

🚀 Feature

This is a follow up to PR #710.

Currently, Thunder indexing supports the following (credit to @mruberry):

  • basic indexing
  • advanced indexing with a sequence of 0D or 1D integer tensors (possibly starting with an ellipsis)
  • advanced indexing with a single list or a sequence of lists (e.g. a[[1,2]], a[[1,2], [1,2]])
  • indexing with a sequence that consists of a single number and...
  • combinations of the basic and advanced indexing that Thunder supports (e.g. a[slice(1,3), [1,2]], a[None, [2,3]])

I am trying to extend this to support other sequences, such as a[(1,2),]. I am looking for advice and suggestions on what other kinds of indexing we want to support in Thunder.

cc @mruberry, @t-vi

k223kim • Jul 14 '24 13:07

To my mind, we should support direct indexing with a single int | slice | Ellipsis | list[int] | Tensor | None, and with a tuple[int | slice | Ellipsis | Sequence[int] | Tensor | None, ...], for now. (That is, I am less set on supporting list[tuple[int, ...]].) This would also give us a clear-cut explanation of what is and isn't supported. WDYT?
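A minimal sketch of that proposed rule as a key checker, assuming plain Python values. is_supported_key and its helpers are hypothetical names, not Thunder APIs, and the tensor check is a duck-typed stand-in so the sketch runs without torch (a real version would test for torch.Tensor or Thunder's own tensor type). bool is excluded from int because boolean keys are out of scope.

```python
from collections.abc import Sequence


def _is_int(x) -> bool:
    # bool is a subclass of int in Python; exclude it, since boolean keys are out of scope
    return isinstance(x, int) and not isinstance(x, bool)


def _is_tensor_like(x) -> bool:
    # Hypothetical duck-typed stand-in so the sketch runs without torch;
    # a real check would use isinstance(x, torch.Tensor) or Thunder's tensor type
    return hasattr(x, "shape") and hasattr(x, "dtype")


def _is_basic_or_tensor(k) -> bool:
    # int | slice | Ellipsis | Tensor | None
    return k is None or k is Ellipsis or isinstance(k, slice) or _is_int(k) or _is_tensor_like(k)


def _is_supported_tuple_entry(k) -> bool:
    # Inside a tuple key, any Sequence[int] (a list or tuple of ints) is also accepted
    if _is_basic_or_tensor(k):
        return True
    if isinstance(k, Sequence) and not isinstance(k, (str, bytes)):
        return all(_is_int(i) for i in k)
    return False


def is_supported_key(key) -> bool:
    """Hypothetical check for the proposed subset of indexing keys."""
    if isinstance(key, tuple):
        # tuple[int | slice | Ellipsis | Sequence[int] | Tensor | None, ...]
        return all(_is_supported_tuple_entry(k) for k in key)
    if isinstance(key, list):
        # A bare list must be list[int]; list[tuple[int, ...]] stays unsupported
        return all(_is_int(i) for i in key)
    return _is_basic_or_tensor(key)


print(is_supported_key(((1, 2),)))         # a[(1,2),]: True
print(is_supported_key([(1, 2), (3, 4)]))  # list of tuples: False
```

Keeping the top-level list case narrower than the in-tuple case is what makes list[tuple[int, ...]] fall outside the supported set.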

t-vi • Jul 15 '24 14:07

Thunder can conceptually support any indexing that doesn't use a boolean tensor in the key. It may be that some keys are easier to support than others, of course, and we should try to update the documentation/comments to reflect what is and isn't supported and why.
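One way to see why boolean keys stand apart: with integer-array indexing along one dimension, the output length equals the length of the index, which follows from its shape alone, while with a boolean mask it equals the number of True entries, which depends on the mask's values. The helpers below are hypothetical and use plain Python lists in place of tensors.

```python
def int_index_result_len(index: list[int]) -> int:
    # Integer-array indexing along one dimension: one output element per index,
    # so the length follows from the index's shape alone
    return len(index)


def bool_mask_result_len(mask: list[bool]) -> int:
    # Boolean-mask indexing: one output element per True entry,
    # so the length depends on the mask's values, not just its shape
    return sum(1 for m in mask if m)


# Two masks with the same shape can produce different output lengths
print(bool_mask_result_len([True, False, True]))   # 2
print(bool_mask_result_len([False, False, True]))  # 1
print(int_index_result_len([0, 2, 1]))             # 3
```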

The clang indexing operations should (probably) be consistent with NumPy, but the torch indexing operations would, ideally, be consistent with PyTorch's behavior, and that may require clang indexing operations with more flexibility. That is, both PyTorch-style and NumPy-style indexing should be implementable, and the clang operations should be sensible and facilitate those implementations.

I don't know exactly how NumPy's and PyTorch's indexing diverge. It might be that there is a clear separation between the two, or maybe PyTorch just has a series of difficult-to-emulate bugs with no logic to them. It would be nice to understand this better and to enable tests for the divergent behavior. We currently don't run many indexing test cases because of this divergence, but we should probably just test those cases against NumPy, and later test a thunder.torch.getitem against torch rather than NumPy.
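A minimal sketch of that differential-testing idea: run the same keys through the implementation under test and a reference (NumPy now, torch later), and record where they disagree, including when only one side raises. compare_indexing is a hypothetical helper, and the two getitem functions are toy 1-D stand-ins so the sketch runs without NumPy or torch.

```python
def compare_indexing(candidate, reference, data, keys):
    """Hypothetical differential-testing helper.

    Returns (key, candidate_outcome, reference_outcome) for every key where the
    two implementations disagree. An outcome is ("ok", value) or ("error", None);
    both sides raising counts as agreement.
    """
    def run(fn, key):
        try:
            return ("ok", fn(data, key))
        except (IndexError, TypeError, ValueError):
            return ("error", None)

    mismatches = []
    for key in keys:
        got, expected = run(candidate, key), run(reference, key)
        if got != expected:
            mismatches.append((key, got, expected))
    return mismatches


def reference_getitem(data, key):
    # Toy reference: a list key gathers elements (negative indices wrap)
    if isinstance(key, list):
        return [data[i] for i in key]
    return data[key]


def candidate_getitem(data, key):
    # Deliberately incomplete toy implementation: rejects negative indices in list keys
    if isinstance(key, list):
        if any(i < 0 for i in key):
            raise IndexError("negative indices are not handled")
        return [data[i] for i in key]
    return data[key]


data = [10, 20, 30]
keys = [1, slice(0, 2), [0, 2], [-1, 0]]
for key, got, expected in compare_indexing(candidate_getitem, reference_getitem, data, keys):
    print(key, got, expected)  # [-1, 0] ('error', None) ('ok', [30, 10])
```

With real arrays, the plain != comparison would need to be replaced by something like np.array_equal.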

mruberry • Jul 15 '24 15:07

Thanks all! I think there could be two steps:

  • Decide how much of NumPy's indexing behavior we want to cover.
  • Check whether PyTorch-style indexing diverges from NumPy within the subset we decide to support.

@mruberry Mike, do you have any thoughts on how much of NumPy's indexing we want to cover? Once that is decided, I can look into whether PyTorch diverges from it.

k223kim • Jul 15 '24 15:07

Quoting @k223kim: "@mruberry Mike, do you have any thoughts regarding how far we want to cover regarding NumPy's indexing?"

Ideally we'd support all of it (except boolean tensors in the key).

mruberry • Jul 15 '24 15:07

@mruberry OK. Regarding advanced indexing in NumPy (per the reference you provided earlier), it is defined as follows:

Advanced indexing is triggered when the selection object, obj, is a non-tuple sequence object, an ndarray (of data type integer or bool), or a tuple with at least one sequence object or ndarray (of data type integer or bool).

Just to make sure I am not misunderstanding anything, this includes the following kinds of indexing:

  • non-tuple sequence object (e.g. a[[1,2]], a[[1,2], [1,2]])
  • ndarray: in our case, this would be indexing with Tensors
  • tuple with at least one sequence object or ndarray (I am not sure I understand this one properly), e.g. a[([1,2])], and maybe a[([1,2], [1,2])] and a[((1,2), (1,2))]? Does this also include a[(1,2), (1,2)]?

I would appreciate it if you could explain this. I just want to make sure the next indexing PR supports what we want it to support (and that it aligns with Tom's comment above!).
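The NumPy rule quoted above can be sketched as a small classifier. One Python detail matters for the examples: parentheses without a comma do not create a tuple, so a[([1,2])] is the same as a[[1,2]] (a non-tuple sequence), a[(1,2)] is the same as a[1,2] (basic indexing of two dimensions), and a[(1,2),] passes a one-element tuple that contains a sequence. triggers_advanced_indexing is a hypothetical name, and the array check is a duck-typed stand-in for an integer or bool ndarray or Tensor.

```python
from collections.abc import Sequence


def _is_array_like(x) -> bool:
    # Hypothetical stand-in for "ndarray (of data type integer or bool)" or a Tensor
    return hasattr(x, "shape") and hasattr(x, "dtype")


def _is_sequence(x) -> bool:
    # Strings are Sequences in Python but are not valid index sequences
    return isinstance(x, Sequence) and not isinstance(x, (str, bytes))


def triggers_advanced_indexing(obj) -> bool:
    """Sketch of NumPy's rule for when a selection object triggers advanced indexing."""
    if isinstance(obj, tuple):
        # A tuple triggers it if at least one entry is a sequence or an array
        return any(_is_sequence(k) or _is_array_like(k) for k in obj)
    # A non-tuple sequence or an array triggers it on its own
    return _is_sequence(obj) or _is_array_like(obj)


for key in ([1, 2], ([1, 2]), ([1, 2], [1, 2]), ((1, 2), (1, 2)), ((1, 2),), (1, 2), (slice(1, 3), None)):
    print(repr(key), triggers_advanced_indexing(key))
```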

k223kim • Jul 15 '24 15:07

Quoting @k223kim: "Does this also include a[(1,2), (1,2)]?"

Yep! That last one is equivalent to ((1, 2), (1, 2)).
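The equivalence comes from how Python builds the key: inside square brackets, comma-separated entries are packed into a tuple before __getitem__ is called. A throwaway class (hypothetical, for illustration only) that returns whatever key it receives makes this visible:

```python
class ShowKey:
    # Hypothetical helper: returns the exact key object Python passes to __getitem__
    def __getitem__(self, key):
        return key


a = ShowKey()
print(a[(1, 2), (1, 2)])  # ((1, 2), (1, 2))
print(a[(1, 2),])         # ((1, 2),)  a one-element tuple holding a sequence
print(a[(1, 2)])          # (1, 2)     the same key as a[1, 2]
print(a[([1, 2])])        # [1, 2]     the same key as a[[1, 2]]
```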

It might be easiest to implement more advanced indexing features in isolation, then look at combined cases, probably starting with the combined indexing case where all advanced indexing elements are contiguous in the key. I think that case can be rewritten as an advanced index operation followed by a basic indexing operation?
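A minimal sketch of that rewrite under narrow assumptions: the key is a tuple of slices and 1-D integer lists, the lists sit next to each other, and they broadcast to a 1-D shape, so the advanced block collapses into a single output dimension. Ints, None and Ellipsis are left out. split_contiguous_key is a hypothetical helper. The first step gathers along the advanced dimensions while leaving the sliced ones whole; the second applies the original slices around the new gathered dimension.

```python
def split_contiguous_key(key: tuple) -> tuple:
    """Hypothetical helper: rewrite a[key] as a[adv_key][basic_key].

    Assumes `key` is a tuple of slices and 1-D integer lists, with the lists
    adjacent to each other and broadcasting to a 1-D shape.
    """
    adv_positions = [i for i, k in enumerate(key) if isinstance(k, list)]
    if not adv_positions:
        raise ValueError("key has no advanced (list) entries")
    first, last = adv_positions[0], adv_positions[-1]
    if adv_positions != list(range(first, last + 1)):
        raise ValueError("advanced entries are not contiguous")
    # Step 1 (advanced): keep the lists, leave every sliced dimension whole
    adv_key = tuple(k if isinstance(k, list) else slice(None) for k in key)
    # Step 2 (basic): the lists collapsed into one dimension at position `first`,
    # so apply the original slices around a single full slice
    basic_key = key[:first] + (slice(None),) + key[last + 1:]
    return adv_key, basic_key


adv_key, basic_key = split_contiguous_key((slice(1, 3), [0, 1], [2, 3], slice(None, None, 2)))
print(adv_key)    # (slice(None, None, None), [0, 1], [2, 3], slice(None, None, None))
print(basic_key)  # (slice(1, 3, None), slice(None, None, None), slice(None, None, 2))
# With NumPy available, np.array_equal(a[key], a[adv_key][basic_key]) is expected to hold
# for keys that satisfy the assumptions above.
```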

mruberry • Jul 15 '24 16:07

Yes, absolutely! I will focus on the more advanced indexing features first, as you mentioned, and follow up with combinations of basic and advanced indexing. Thank you for the explanation. I will follow up with a draft PR!

k223kim • Jul 15 '24 17:07

Quoting @mruberry: "The clang indexing operations should (probably) be consistent with NumPy, but the torch indexing operations would, ideally, be consistent with PyTorch's behavior, and that may require clang indexing operations with more flexibility."

As @lantiga commented in #710, we should keep in mind that our main (and currently only) use case is PyTorch, so I would not sweat NumPy compatibility much (but maybe leave a comment if we know something is not 100% NumPy-compatible).

That said, it's probably not important whether we ignore the second and later ... (PyTorch) or raise an error (NumPy). 😉
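For illustration only, the two behaviors could be a policy switch in an ellipsis-expansion step. expand_ellipsis and its policy argument are hypothetical; the "ignore" branch follows the description in this comment rather than verified PyTorch behavior, and only ints, slices, None and Ellipsis are handled.

```python
def expand_ellipsis(key: tuple, ndim: int, policy: str = "error") -> tuple:
    """Hypothetical helper: replace the Ellipsis in `key` with full slices.

    policy="error":  more than one Ellipsis raises IndexError (NumPy-style, per this thread).
    policy="ignore": Ellipsis entries after the first are dropped (PyTorch-style, per this thread).
    """
    if policy not in ("error", "ignore"):
        raise ValueError(f"unknown policy: {policy!r}")
    positions = [i for i, k in enumerate(key) if k is Ellipsis]
    if len(positions) > 1 and policy == "error":
        raise IndexError("more than one Ellipsis in key")
    # Keep only the first Ellipsis (a no-op when there is at most one)
    key = tuple(k for i, k in enumerate(key) if k is not Ellipsis or i in positions[:1])
    # Every entry except None and Ellipsis consumes one input dimension
    consumed = sum(1 for k in key if k is not None and k is not Ellipsis)
    if consumed > ndim:
        raise IndexError(f"too many indices: {consumed} for {ndim} dimensions")
    if not positions:
        return key
    first = positions[0]  # still valid: only entries after it were dropped
    return key[:first] + (slice(None),) * (ndim - consumed) + key[first + 1:]


print(expand_ellipsis((0, Ellipsis, None), 3))
print(expand_ellipsis((Ellipsis, 1, Ellipsis), 2, policy="ignore"))
```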

t-vi • Jul 15 '24 18:07