Typing error using LanguageModel.trace() as a context manager

Open chanind opened this issue 1 year ago • 0 comments

The following code works, but pyright/pylance gives a typing error:

```python
from nnsight import LanguageModel

model = LanguageModel('google/gemma-2-2b')

with model.trace('hi'):
    acts = model.model.layers[0].output.save()
```

The errors are the following:

Object of type "InterventionProxy" cannot be used with "with" because it does not implement __enter__
  Attribute "__enter__" is unknown

Object of type "InterventionProxy" cannot be used with "with" because it does not implement __exit__

It looks like this is because LanguageModel.__new__() is typed as returning either a LanguageModel or an Envoy. Envoy.__call__() returns an InterventionProxy, which is not a context manager. Since the type checker can't tell whether model is a LanguageModel or an Envoy, it has to assume that model might be an Envoy, and so it rejects using model.trace() in a with statement. It seems like __new__() should be typed to return just a LanguageModel, no?
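For anyone who wants to reproduce the pattern outside nnsight, here is a minimal sketch using toy classes. `ToyModel`, `ToyEnvoy`, and `ToyProxy` are hypothetical stand-ins, not nnsight's real classes, and the union annotation on `__new__` is my assumption about what triggers the report. The code runs fine, but I'd expect a checker such as pyright to flag the `with` statement unless the type is narrowed first, for example with an `isinstance` assertion:

```python
from __future__ import annotations

from contextlib import contextmanager
from typing import Iterator, Union


class ToyProxy:
    """Hypothetical stand-in for a proxy object; it is not a context manager."""


class ToyEnvoy:
    """Hypothetical stand-in for a wrapper whose trace() returns a plain proxy."""

    def trace(self, prompt: str) -> ToyProxy:
        return ToyProxy()


class ToyModel:
    # Assumption: a union return annotation on __new__ makes the checker treat
    # ToyModel(...) as "ToyModel | ToyEnvoy" instead of just ToyModel.
    def __new__(cls, name: str) -> Union[ToyModel, ToyEnvoy]:
        return super().__new__(cls)

    def __init__(self, name: str) -> None:
        self.name = name

    @contextmanager
    def trace(self, prompt: str) -> Iterator[str]:
        # Yield the prompt so the with-block has something to bind.
        yield prompt


model = ToyModel("demo")

# At runtime this is always a ToyModel. The assertion also narrows the static
# type, so the checker no longer considers the ToyEnvoy branch.
assert isinstance(model, ToyModel)

with model.trace("hi") as prompt:
    seen = prompt
```

Without the `isinstance` line, the static type of `model.trace("hi")` would include `ToyProxy`, which has no `__enter__` or `__exit__`. Narrowing like this is only a user-side workaround; annotating `__new__` to return the class itself would remove the need for it.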

chanind, Oct 12 '24 20:10