
IndexError when flat inputs are concatenated during trace

Open austinleedavis opened this issue 7 months ago • 0 comments

Description

The NNSight object raises an IndexError when unbatched token IDs are used as input while tracing in a loop. This bug is easy to trip over, and the error message is not helpful for diagnosing it. It would be nice if the trace invocation checked whether the inputs are correctly batched (e.g., via len(input_ids.shape)) before applying the tensor concatenation. A simple workaround on the caller's side is to use tracer.invoke(input_ids.unsqueeze(0)).

Root Cause

When an NNSight object batches inputs, it uses torch.concatenate() to stack the input_ids tensors along dimension zero (0). However, if a one-dimensional tensor is used as input, i.e., len(input_ids.shape) == 1, then concatenation appends subsequent inputs to the original input_ids tensor rather than stacking them. After multiple concatenations, the length of the "batched" input exceeds the model's context window, so the downstream forward pass raises an IndexError while embedding the inputs, because the position indices exceed the number of rows in the position-embedding matrix.
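The difference between stacking and appending can be reproduced with plain torch.cat, independently of nnsight (a minimal illustration of the root cause):

```python
import torch

# Batched (2-D) input: concatenating along dim 0 stacks rows,
# producing a growing batch dimension.
batched = torch.tensor([[1]])                   # shape (1, 1)
stacked = torch.cat([batched, batched], dim=0)  # shape (2, 1)
assert stacked.shape == torch.Size([2, 1])

# Unbatched (1-D) input: concatenating along dim 0 appends tokens,
# producing an ever-longer "sequence" instead of a batch.
flat = torch.tensor([1])                        # shape (1,)
appended = torch.cat([flat, flat], dim=0)       # shape (2,)
assert appended.shape == torch.Size([2])
```

In the failing example below, each loop iteration therefore lengthens the sequence by one token until it exceeds the context window.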

Working Example

import torch

# `model` is assumed to be a preloaded GPT-2-style HuggingFace model
input_ids = torch.tensor([[1]])  # <-- shape = torch.Size([1, 1])
model = NNSight(model)
with model.trace() as tracer:
    for i in range(model.transformer.wpe.weight.shape[0]):
        with tracer.invoke(input_ids):
            pass

Failing Example

import torch

# `model` is assumed to be a preloaded GPT-2-style HuggingFace model
input_ids = torch.tensor([1])  # <-- shape = torch.Size([1])
model = NNSight(model)
with model.trace() as tracer:
    for i in range(model.transformer.wpe.weight.shape[0]):
        with tracer.invoke(input_ids):
            pass
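Until the library validates input shapes itself, callers can normalize inputs before invoking. The sketch below is a hypothetical helper (not part of nnsight) that applies the unsqueeze workaround only when needed:

```python
import torch

def ensure_batched(input_ids: torch.Tensor) -> torch.Tensor:
    """Add a leading batch dimension to 1-D token-ID tensors;
    pass already-batched tensors through unchanged."""
    if input_ids.dim() == 1:
        return input_ids.unsqueeze(0)
    return input_ids

# 1-D input gains a batch dimension; 2-D input is untouched.
assert ensure_batched(torch.tensor([1])).shape == torch.Size([1, 1])
assert ensure_batched(torch.tensor([[1]])).shape == torch.Size([1, 1])
```

Calling tracer.invoke(ensure_batched(input_ids)) would make both examples behave like the working one.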

Info

  • nnsight==0.2.16
  • torch==2.3.1
  • transformers==4.40.1

austinleedavis, Jul 20 '24