nnsight
IndexError when flat inputs are concatenated during trace
Description
The NNSight object raises an IndexError when unbatched token IDs are used as input while tracing in a loop. The bug is easy to trigger and hard to diagnose, because the error message gives no hint that the input shape is the cause. It would be nice if the trace invocation checked that the inputs are correctly batched, e.g. using len(input_ids.shape), before applying the tensor concatenation. A simple workaround is to use tracer.invoke(input_ids.unsqueeze(0)).
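Until such a check exists, the workaround can be wrapped in a small helper. The following is a minimal sketch: ensure_batched is a hypothetical name, not part of nnsight, and it assumes the input is a torch tensor of token IDs.

```python
import torch


def ensure_batched(input_ids: torch.Tensor) -> torch.Tensor:
    """Return input_ids with a leading batch dimension.

    Hypothetical helper for illustration; not part of nnsight.
    """
    if input_ids.dim() == 1:  # equivalent to len(input_ids.shape) == 1
        return input_ids.unsqueeze(0)
    return input_ids


flat = torch.tensor([1, 2, 3])   # shape (3,)
batched = ensure_batched(flat)   # shape (1, 3)
```

Inside the loop, the call would then read tracer.invoke(ensure_batched(input_ids)), which leaves already batched inputs untouched.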
Root Cause
When an NNSight object batches inputs, it uses the torch.concatenate() function to join the input_ids tensors along dimension zero (0). However, if a one-dimensional tensor is used as input, i.e., len(input_ids.shape) == 1, then concatenation appends each subsequent input to the end of the original input_ids tensor rather than stacking the inputs as separate rows. After enough concatenations, the "batched" input is a single sequence whose length exceeds the model context window. The forward pass can then raise an IndexError when the inputs are embedded, because the position indices exceed the number of rows in the position embedding matrix.
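The difference between the two input shapes can be reproduced with plain torch. This sketch only mirrors the concatenation described above and does not call nnsight's internal batching code; it uses torch.cat, for which torch.concatenate is an alias.

```python
import torch

flat = torch.tensor([1])       # shape (1,): no batch dimension
batched = torch.tensor([[1]])  # shape (1, 1): batch of one sequence

# Simulate three invocations being joined along dimension 0.
flat_cat = torch.cat([flat, flat, flat], dim=0)
batched_cat = torch.cat([batched, batched, batched], dim=0)

print(flat_cat.shape)     # torch.Size([3])    -> one sequence that keeps growing
print(batched_cat.shape)  # torch.Size([3, 1]) -> three sequences of length 1
```

With flat inputs, every extra invocation lengthens the single sequence by one token, so the sequence length grows with the number of loop iterations instead of the batch size.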
Working Example
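import torch
from nnsight import NNSight

# `model` is assumed to be an already loaded GPT-2-style torch module.
input_ids = torch.tensor([[1]])  # <-- shape = torch.Size([1, 1])
model = NNSight(model)
with model.trace() as tracer:
    for i in range(model.transformer.wpe.weight.shape[0]):
        with tracer.invoke(input_ids):
            pass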
Failing Example
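import torch
from nnsight import NNSight

# `model` is assumed to be an already loaded GPT-2-style torch module.
input_ids = torch.tensor([1])  # <-- shape = torch.Size([1])
model = NNSight(model)
with model.trace() as tracer:
    for i in range(model.transformer.wpe.weight.shape[0]):
        with tracer.invoke(input_ids):
            pass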
Info
- nnsight==0.2.16
- torch==2.3.1
- transformers==4.40.1