lightning-thunder
Support more indexing operators (`index_copy` and `index_add`)
🚀 Feature
Motivation
Currently, only `torch.index_add` is implemented.
Pitch
import torch
import thunder
def index_copy_method(x, t, index): return x.index_copy(0, index, t)
def index_add_method(x, t, index): return x.index_add(0, index, t)
def index_copy_method_inplace(x, t, index): return x.index_copy_(0, index, t)
def index_add_method_inplace(x, t, index): return x.index_add_(0, index, t)
def index_copy_function(x, t, index): return torch.index_copy(x, 0, index, t)
# torch.index_add (function form) already works, so it is left out of the test:
# def index_add_function(x, t, index): return torch.index_add(x, 0, index, t)
x = torch.zeros(5, 3)
t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
index = torch.tensor([0, 4, 2])
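# Snapshot globals() with list() so binding loop variables cannot change the dict mid-iteration
for fn in [f for f in list(globals().values()) if hasattr(f, "__name__") and f.__name__.startswith("index_")]:
    name = fn.__name__
    jfn = thunder.jit(fn)
    try:
        # Pass a fresh copy of x so the in-place variants don't mutate the input seen by later runs
        out = jfn(x.clone(), t, index)
        print(f"{name} worked:\n{out}")
    except Exception as e:
        print(f"{name} failed: {e}")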
The in-place variants probably need to be functionalized; otherwise, we should raise a `NotImplementedError` for them.
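Functionalizing an in-place call means tracing it as its out-of-place counterpart followed by a single write-back; in PyTorch terms, `x.index_copy_(0, index, t)` behaves like `x.copy_(x.index_copy(0, index, t))`. A minimal sketch of that rewrite, assuming a pure-Python model of dim-0 semantics on nested lists (the helpers `index_copy` and `index_copy_inplace_functionalized` are hypothetical and are not thunder's actual mechanism):

```python
from typing import List

Row = List[float]


def index_copy(x: List[Row], index: List[int], t: List[Row]) -> List[Row]:
    """Out-of-place model of x.index_copy(0, index, t): returns new rows, x is untouched."""
    out = [list(row) for row in x]
    for i, src in zip(index, t):
        out[i] = list(src)
    return out


def index_copy_inplace_functionalized(x: List[Row], index: List[int], t: List[Row]) -> List[Row]:
    """Model of x.index_copy_(0, index, t) rewritten as an out-of-place op plus a write-back."""
    result = index_copy(x, index, t)  # pure computation, no mutation
    for dst, src in zip(x, result):   # the only mutation: copy the result back into x's rows
        dst[:] = src
    return x                          # in-place ops return the mutated input


x = [[0.0] * 3 for _ in range(5)]
t = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
out = index_copy_inplace_functionalized(x, [0, 4, 2], t)
print(out)  # [[1.0, 2.0, 3.0], [0.0, 0.0, 0.0], [7.0, 8.0, 9.0], [0.0, 0.0, 0.0], [4.0, 5.0, 6.0]]
```

Because the mutation is isolated in the final write-back, everything before it can be traced as an ordinary pure function.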
Additional context
Lit-GPT uses `index_copy_`, but since the target is a zeros tensor, it can be replaced with `index_add`.
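This replacement relies on `0 + v == v`, and it assumes `index` has no duplicate entries: `index_add` accumulates rows that share an index, while PyTorch documents `index_copy` with duplicate indices as nondeterministic. A small pure-Python model of dim-0 semantics (hypothetical helpers, not the PyTorch or thunder implementations) illustrates both points:

```python
from typing import List

Row = List[float]


def index_copy(x: List[Row], index: List[int], t: List[Row]) -> List[Row]:
    """Model of x.index_copy(0, index, t): row index[k] of the result becomes t[k]."""
    out = [list(row) for row in x]
    for i, src in zip(index, t):
        out[i] = list(src)
    return out


def index_add(x: List[Row], index: List[int], t: List[Row]) -> List[Row]:
    """Model of x.index_add(0, index, t): t[k] is added to row index[k], accumulating on repeats."""
    out = [list(row) for row in x]
    for i, src in zip(index, t):
        out[i] = [a + b for a, b in zip(out[i], src)]
    return out


zeros = [[0.0] * 3 for _ in range(5)]
t = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]

# Unique indices on a zeros target: adding into 0 is the same as copying.
same = index_add(zeros, [0, 4, 2], t) == index_copy(zeros, [0, 4, 2], t)
print(same)  # True

# Duplicate indices: index_add sums rows 0 and 1 of t into row 0, so the rewrite no longer holds.
print(index_add(zeros, [0, 0, 2], t)[0])  # [5.0, 7.0, 9.0]
```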