Unused renaming scheme makes cache access via the renamed name fail
Even if a renaming is not actually used (`model.language_model` here):
```python
from nnsight import LanguageModel

m = LanguageModel(
    "yujiepan/llama-3.2-tiny-random",
    rename={"model": "foo", "model.language_model": "foo"},
    # rename={"model": "foo"},  # works
)
with m.trace("Hello, world!") as tracer:
    cache = tracer.cache(modules=[m.foo.layers[0]])
    print(cache.model.foo.layers[0].output)
```
This makes the cache access fail:
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[7], line 10
      8 with m.trace("Hello, world!") as tracer:
      9     cache = tracer.cache(modules=[m.foo.layers[0]])
---> 10 print(cache.model.foo.layers[0].output)

File ~/projects/nnterp/.venv/lib/python3.10/site-packages/nnsight/intervention/tracing/tracer.py:156, in Cache.CacheDict.__getattr__(self, attr)
    154     name = self._alias[attr]
    155     name = name.removeprefix(".")
--> 156     return self.__getattr__(name)
    157 else:
    158     raise AttributeError(f"'{attr}' module path was never cached. '{self.__class__.__name__}' has no matching attribute.")

File ~/projects/nnterp/.venv/lib/python3.10/site-packages/nnsight/intervention/tracing/tracer.py:158, in Cache.CacheDict.__getattr__(self, attr)
    156     return self.__getattr__(name)
    157 else:
--> 158     raise AttributeError(f"'{attr}' module path was never cached. '{self.__class__.__name__}' has no matching attribute.")

AttributeError: 'model.language_model' module path was never cached. 'CacheDict' has no matching attribute.
```
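For context, here is a simplified, hypothetical sketch of the failure mode. This is NOT nnsight's actual `CacheDict` implementation, and the direction of the alias map (renamed name → original module path) is an assumption based on the traceback: when a rename maps to a module path that was never cached, the alias lookup resolves to that missing path and raises.

```python
# Hypothetical, minimal model of a CacheDict-style alias lookup,
# for illustration only -- NOT nnsight's actual implementation.
class FakeCacheDict(dict):
    def __init__(self, entries, alias):
        super().__init__(entries)
        # Assumed shape: renamed attribute -> original module path.
        self._alias = alias

    def __getattr__(self, attr):
        if attr in self:
            return self[attr]
        if attr in self._alias:
            # Resolve the alias and retry with the original path.
            name = self._alias[attr].removeprefix(".")
            return self.__getattr__(name)
        raise AttributeError(f"'{attr}' module path was never cached.")


# Only 'model' was actually cached, but the rename scheme also mapped
# 'foo' to the unused path 'model.language_model'.
cache = FakeCacheDict(
    {"model": "cached-activations"},
    alias={"foo": ".model.language_model"},
)

print(cache.model)  # fine: this path was cached
try:
    cache.foo  # alias resolves to a path that was never cached
except AttributeError as e:
    print(e)  # 'model.language_model' module path was never cached.
```

The point is that the unused rename entry still lands in the alias table, so any lookup that resolves through it dead-ends on a path with no cache entry.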
Bump @AdamBelfki3. I don't know if it's an easy fix, but it would be great to have this fixed upstream for nnterp.
@Butanium try this branch and let me know if it handles your use case
@AdamBelfki3 thanks, there's still a problem with path renaming:
```python
from nnsight import LanguageModel

m = LanguageModel("gpt2", rename={"transformer.h": "layers"})
with m.trace("Hi") as t:
    cache = t.cache(modules=[m.layers[0]]).save()

print("Keys:", list(cache.keys()))
print("Trying: cache.model.layers[0].output")
print(cache.model.layers[0].output)  # CRASHES
```
@Butanium The root of the cache entries is `cache.model`; you have to call `print(cache.model.layers[0].output)` instead.
This is because `cache` is just a dictionary, and `cache.model` is what actually contains the top-level inputs and outputs of the model.
Oh yes, sorry, Claude was being dumb.
@AdamBelfki3 I updated my example, which also crashes.
@Butanium are you on the right branch? I just ran it and it worked for me.
@AdamBelfki3 sorry, this is the minimal rename dict that crashes for me:
```python
{
    "transformer": "model",
    "h": "layers",
    "model.layers": "layers",
}
```