
Envoy.iter!

Open · JadenFiotto-Kaufman opened this issue 4 months ago · 0 comments

New feature! @AdamBelfki3 @cadentj

A new way to specify which module iterations (i.e., generation steps) an intervention applies to.

Here I specify that an intervention should apply to all iterations by calling .all() globally:

```python
from nnsight import LanguageModel
from nnsight import list  # nnsight's list, usable inside the tracing context (shadows the builtin)
import torch

model = LanguageModel("meta-llama/Meta-Llama-3.1-8B", device_map="auto", torch_dtype=torch.bfloat16)

with model.generate("hello world", max_new_tokens=5):
    values = list().save()

    # Apply every intervention below to all generation iterations.
    model.all()

    values.append(model.lm_head.output)

print(len(values.value)) # Prints 5
```

If I only want some interventions to apply on every iteration, I can use .all() as a context manager:


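```python
with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    # Only interventions inside this block run on every iteration.
    with model.all():
        values.append(model.lm_head.output)

    # Outside the block: runs once, so other_values gets a single entry.
    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 5
print(len(other_values.value)) # Prints 1
```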
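To make the 5-vs-1 counts concrete without loading a model, here is a toy simulation in plain Python. It is not nnsight's implementation: `ToyGeneration`, `intervene`, and `run` are hypothetical names, and it assumes (consistent with the single entry in `other_values`) that an intervention outside `.all()` runs only on the first iteration.

```python
# Toy model (NOT nnsight's implementation) of how .all() as a context
# manager changes which generation steps an intervention runs on.
from contextlib import contextmanager


class ToyGeneration:
    """Hypothetical stand-in for a generate() call with n_steps iterations."""

    def __init__(self, n_steps):
        self.n_steps = n_steps
        self.interventions = []   # list of (steps, fn) pairs
        self._active_steps = [0]  # assumed default: only the first iteration

    @contextmanager
    def all(self):
        # Inside this block, newly recorded interventions cover every step.
        previous = self._active_steps
        self._active_steps = list(range(self.n_steps))
        try:
            yield
        finally:
            self._active_steps = previous

    def intervene(self, fn):
        # Record fn together with the steps that are active right now.
        self.interventions.append((list(self._active_steps), fn))

    def run(self):
        for step in range(self.n_steps):
            output = f"lm_head output at step {step}"  # placeholder value
            for steps, fn in self.interventions:
                if step in steps:
                    fn(output)


values, other_values = [], []
gen = ToyGeneration(n_steps=5)  # like max_new_tokens=5
with gen.all():
    gen.intervene(values.append)
gen.intervene(other_values.append)
gen.run()
print(len(values), len(other_values))  # 5 1
```

The key design point the sketch captures is that the context manager only changes which steps are attached to interventions recorded inside it; everything recorded afterwards falls back to the default.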

.all() is an alias for .iter[:]. That means you can also select a single iteration with an int, multiple iterations with a list of ints, or a range with a slice:

```python
with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    # Iterations 2 and 3 (the end is exclusive, like a Python slice).
    with model.iter[2:4]:
        values.append(model.lm_head.output)

    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 2
print(len(other_values.value)) # Prints 1
```

```python
with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    # Iterations 0, 1 and 4.
    with model.iter[[0,1,4]]:
        values.append(model.lm_head.output)

    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 3
print(len(other_values.value)) # Prints 1
```
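The indexing rules above can be illustrated with a small pure-Python sketch that maps an index expression to the generation steps it selects. `IterSelector` is a hypothetical name, not part of nnsight, and it assumes the total number of steps (here `max_new_tokens=5`) is known up front, which may not be how nnsight resolves iterations internally.

```python
class IterSelector:
    """Hypothetical stand-in for `.iter`: maps an index expression to the
    generation steps it selects, out of `n_steps` total."""

    def __init__(self, n_steps):
        self.n_steps = n_steps

    def __getitem__(self, spec):
        steps = range(self.n_steps)
        if isinstance(spec, int):
            return [steps[spec]]         # a single iteration
        if isinstance(spec, slice):
            return list(steps[spec])     # a range; [:] selects every step
        return [steps[i] for i in spec]  # an explicit list of iterations

    def all(self):
        # Mirrors the alias described above: .all() == .iter[:]
        return self[:]


it = IterSelector(n_steps=5)  # like max_new_tokens=5
print(it.all())       # [0, 1, 2, 3, 4]
print(it[2:4])        # [2, 3]
print(it[[0, 1, 4]])  # [0, 1, 4]
print(it[3])          # [3]
```

The lengths of these selections (2 for `[2:4]`, 3 for `[[0,1,4]]`) match the number of values collected in the examples above.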

This also works inline on a module:

```python
with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    # Inline: select lm_head's output at iterations 2 and 3 only.
    values.append(model.lm_head.iter[2:4].output)

    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 2
print(len(other_values.value)) # Prints 1
```

The same usage shown above for .all() also applies to .next().

JadenFiotto-Kaufman, Oct 13 '24 04:10