Error: `I do not know what to do with layer Embedding(50304, 512)`
First of all, great library. I have long been looking for a way to get Jacobians and Fisher information matrices for my PyTorch models. While the library works fine with my vision models based on simple convolutional networks, I find it harder to use with Hugging Face pretrained models. To be clear, I believe the embedding layers are the culprit here.
I devised a dataloader that takes a list of strings as input and yields batches as dictionaries with the keys "input_ids" and "attention_mask", whose values are torch.Tensors of integer type.
```python
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from transformers.tokenization_utils import BatchEncoding

torch_model = GPTNeoXForCausalLM.from_pretrained(
    pretrained_model_name_or_path="EleutherAI/pythia-70m-deduped",
    revision="step1000",
    cache_dir=cache_dir,
)
# tokenizer for the same pythia checkpoint
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-70m-deduped",
    revision="step1000",
    cache_dir=cache_dir,
)


class FIMDataLoader(Dataset):
    def __init__(self, text_list, tokenizer, max_length=128):
        self.text_list = text_list
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.text_list)

    def __getitem__(self, idx):
        text = self.text_list[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        return input_ids, attention_mask


def collate_fn(batch):
    input_ids, attention_mask = zip(*batch)
    return BatchEncoding(
        {
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_mask),
        }
    )


def create_dataloader(text_list, tokenizer, batch_size, max_length, shuffle=False):
    dataset = FIMDataLoader(text_list, tokenizer, max_length)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn
    )
    return dataloader
```
Then I instantiate the dataloader:
```python
texts_list = ["The cat is on the table", "Alice and Bob are friends"]
dataloader = create_dataloader(
    texts_list, tokenizer, batch_size=1, max_length=128, shuffle=False
)
```
For a model with a total of 70M parameters, keeping the entire Fisher matrix in memory is prohibitive, so I chose the diagonal approximation, whose storage is proportional to the number of parameters, by using the PMatDiag representation you kindly provide in your library.
I thought this would give me the diagonal of the Fisher information matrix, right? However, something seems to go wrong during the LayerCollection creation.
```python
from nngeometry.metrics import FIM
from nngeometry.object import PMatDiag

FIM(
    model=torch_model,
    loader=dataloader,
    representation=PMatDiag,
    n_output=1,
    device="cpu",
)
```
Instead, I get the following error:
```
Exception                                 Traceback (most recent call last)
Cell In[93], line 6
      3 from nngeometry.metrics import FIM
      4 from nngeometry.object import PMatDiag
----> 6 F_ekfac = FIM(
      7     model=torch_model,
      8     loader=dataloader,
      9     representation=PMatDiag,
     10     n_output=1,
     11     device="cpu",
     12 )

File ~/opt/miniconda3/envs/pythia/lib/python3.10/site-packages/nngeometry/metrics.py:147, in FIM(model, loader, representation, n_output, variant, device, function, layer_collection)
    144     return model(d[0].to(device))
    146 if layer_collection is None:
--> 147     layer_collection = LayerCollection.from_model(model)
    149 if variant == 'classif_logits':
    151     def function_fim(*d):

File ~/opt/miniconda3/envs/pythia/lib/python3.10/site-packages/nngeometry/layercollection.py:50, in LayerCollection.from_model(model, ignore_unsupported_layers)
     48 elif not ignore_unsupported_layers:
     49     if len(list(mod.children())) == 0 and len(list(mod.parameters())) > 0:
---> 50         raise Exception('I do not know what to do with layer ' + str(mod))
     52 return lc

Exception: I do not know what to do with layer Embedding(50304, 512)
```
It looks like the error is caused by the embedding layers: there are two of them, one mapping token ids from the vocabulary space (size 50304) to the latent space (size 512), and another one at the end doing the reverse.
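My best guess so far is to build the LayerCollection myself and skip the unsupported layers, roughly as sketched below. I am only assuming this is the intended use: the ignore_unsupported_layers flag and the layer_collection / function arguments appear in the traceback above, but I have not found how the batch is handed to the function callback, so this is just what I have in mind.

```python
from nngeometry.layercollection import LayerCollection
from nngeometry.metrics import FIM
from nngeometry.object import PMatDiag

# Guess: build a LayerCollection that silently skips layers nngeometry
# does not support (the embeddings), instead of raising an Exception.
lc = LayerCollection.from_model(torch_model, ignore_unsupported_layers=True)


# Guess: tell FIM how to run the model on one batch from my dataloader.
# I am not sure in what form the batch reaches this callback.
def function(*d):
    batch = d[0]
    return torch_model(
        input_ids=batch["input_ids"].to("cpu"),
        attention_mask=batch["attention_mask"].to("cpu"),
    ).logits


F_diag = FIM(
    model=torch_model,
    loader=dataloader,
    representation=PMatDiag,
    n_output=1,
    device="cpu",
    function=function,
    layer_collection=lc,
)
```

But if I understand correctly, this would simply leave the embedding parameters out of the Fisher matrix.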
What should I do to get the FIM diagonal for all model parameters?
Many thanks, and again, great package.