
Support more indexing operators (`index_copy` and `index_add`)

Status: open. carmocca opened this issue 2 years ago (0 comments).

🚀 Feature

Motivation

Currently, only `torch.index_add` (the function form) is supported.

Pitch

```python
import torch
import thunder

# Variants that do not currently work under thunder.jit
def index_copy_method(x, t, index):
    return x.index_copy(0, index, t)

def index_add_method(x, t, index):
    return x.index_add(0, index, t)

def index_copy_method_inplace(x, t, index):
    return x.index_copy_(0, index, t)

def index_add_method_inplace(x, t, index):
    return x.index_add_(0, index, t)

def index_copy_function(x, t, index):
    return torch.index_copy(x, 0, index, t)

# This one works:
# def index_add_function(x, t, index):
#     return torch.index_add(x, 0, index, t)

x = torch.zeros(5, 3)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
index = torch.tensor([0, 4, 2])

for fn in (index_copy_method, index_add_method, index_copy_method_inplace,
           index_add_method_inplace, index_copy_function):
    jfn = thunder.jit(fn)
    try:
        # Pass a copy so an in-place variant cannot mutate `x` for later runs
        out = jfn(x.clone(), t, index)
        print(f"{fn.__name__} worked:\n{out}")
    except Exception as e:
        print(f"{fn.__name__} failed: {e}")
```

The in-place variants (`index_copy_` and `index_add_`) probably need to be functionalized. Otherwise, we should raise a `NotImplementedError` for them.
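To make "functionalized" concrete, here is a minimal pure-Python sketch, assuming a 2-D tensor is modelled as a list of rows and only `dim=0` is handled. It is not thunder's or PyTorch's implementation; the helpers `index_copy` and `functionalized_index_copy_` are hypothetical. The in-place call is rewritten as the out-of-place op followed by an explicit write-back into the input, which is the general shape a functionalization pass produces.

```python
from typing import List

Rows = List[List[float]]  # a 2-D "tensor" modelled as a list of rows (hypothetical)


def index_copy(x: Rows, index: List[int], t: Rows) -> Rows:
    """Out-of-place model of x.index_copy(0, index, t); x is only read."""
    out = [row[:] for row in x]
    for dst, src in zip(index, t):
        out[dst] = list(src)  # row i of t replaces row index[i]
    return out


def functionalized_index_copy_(x: Rows, index: List[int], t: Rows) -> Rows:
    """Model of x.index_copy_(0, index, t) as a pure op plus an explicit write-back."""
    # 1) Pure computation: no mutation happens here.
    result = index_copy(x, index, t)
    # 2) Write the result back into the caller's container (the role Tensor.copy_ would play).
    x[:] = result
    # Like PyTorch's in-place methods, return the mutated input itself.
    return x


x = [[0.0] * 3 for _ in range(5)]
t = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
alias = x  # another reference to the same container
out = functionalized_index_copy_(x, [0, 4, 2], t)
print(alias[4])  # [4.0, 5.0, 6.0]: the write-back is visible through every reference
```

Splitting the op this way keeps the computation itself side-effect free, so only the final write-back has to be handled as a mutation.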

Additional context

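Lit-GPT uses `index_copy_`, but since the target is a zeros tensor, it can replace it with `index_add` (this relies on `index` having no duplicate entries, because `index_add` sums rows that share an index).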
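The substitution can be checked with a small pure-Python model of both ops along `dim=0`, assuming tensors are lists of rows. The `index_copy` and `index_add` functions below are hypothetical stand-ins, not PyTorch code, and resolving duplicate indices as last-write-wins in `index_copy` is a choice of this sketch rather than documented PyTorch behavior.

```python
from typing import List

Rows = List[List[float]]  # a 2-D "tensor" modelled as a list of rows (hypothetical)


def index_copy(x: Rows, index: List[int], t: Rows) -> Rows:
    """Model of x.index_copy(0, index, t); duplicates resolve last-write-wins here."""
    out = [row[:] for row in x]
    for dst, src in zip(index, t):
        out[dst] = list(src)
    return out


def index_add(x: Rows, index: List[int], t: Rows) -> Rows:
    """Model of x.index_add(0, index, t): rows of t are summed into x."""
    out = [row[:] for row in x]
    for dst, src in zip(index, t):
        out[dst] = [a + b for a, b in zip(out[dst], src)]
    return out


zeros = [[0.0] * 3 for _ in range(5)]
t = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]

# Unique indices (as in the repro): adding into zeros equals copying.
print(index_add(zeros, [0, 4, 2], t) == index_copy(zeros, [0, 4, 2], t))  # True

# Duplicate index 0: index_add accumulates rows, index_copy keeps a single row.
print(index_add(zeros, [0, 0, 2], t)[0])   # [5.0, 7.0, 9.0]
print(index_copy(zeros, [0, 0, 2], t)[0])  # [4.0, 5.0, 6.0]
```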

carmocca, Sep 26 '23 23:09