Error: `I do not know what to do with layer Embedding(50304, 512)`

Open · CarloNicolini opened this issue 1 year ago • 4 comments

First of all, great library: I have long been looking for a way to get Jacobians and Fisher information matrices for my PyTorch models. While the library works fine with my vision models based on simple convolutional networks, I find it harder to use with Hugging Face pretrained models. To be clear, I believe the embedding layers are the culprit here.

I devised a dataloader that takes a list of strings as input and yields each batch as a dictionary with "input_ids" and "attention_mask" keys, whose values are integer torch.Tensors.

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from transformers.tokenization_utils import BatchEncoding

cache_dir = None  # or a local path where the checkpoint should be cached

torch_model = GPTNeoXForCausalLM.from_pretrained(
    pretrained_model_name_or_path="EleutherAI/pythia-70m-deduped",
    revision="step1000",
    cache_dir=cache_dir,
)

# tokenizer from the same checkpoint
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/pythia-70m-deduped",
    revision="step1000",
    cache_dir=cache_dir,
)
# the GPT-NeoX tokenizer has no pad token by default, so reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token

class FIMDataLoader(Dataset):
    def __init__(self, text_list, tokenizer, max_length=128):
        self.text_list = text_list
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.text_list)

    def __getitem__(self, idx):
        text = self.text_list[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        return input_ids, attention_mask


def collate_fn(batch):
    input_ids, attention_mask = zip(*batch)

    return BatchEncoding(
        {
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_mask),
        }
    )


def create_dataloader(text_list, tokenizer, batch_size, max_length, shuffle=False):
    dataset = FIMDataLoader(text_list, tokenizer, max_length)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=collate_fn
    )
    return dataloader

Then I instantiate the dataloader:

texts_list = ["The cat is on the table", "Alice and Bob are friends"]
dataloader = create_dataloader(
    texts_list, tokenizer, batch_size=1, max_length=128, shuffle=False
)
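
As a quick sanity check, I would expect batches shaped like this (with batch_size=1 and max_length=128):

batch = next(iter(dataloader))
print(batch["input_ids"].shape)       # torch.Size([1, 128])
print(batch["attention_mask"].shape)  # torch.Size([1, 128])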

For a model with a total of 70M parameters, keeping the entire Fisher matrix in memory is prohibitive, so I chose the diagonal approximation, whose storage is proportional to the number of parameters, by using the PMatDiag representation you kindly provide in your library.
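
Back-of-the-envelope, assuming float32 storage:

n_params = 70e6
print(n_params * 4 / 1e9)        # diagonal: ~0.28 GB
print(n_params ** 2 * 4 / 1e15)  # dense Fisher matrix: ~19.6 PB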

I thought this would give me the diagonal of the Fisher information matrix, right? However, I get an error that seems related to the LayerCollection creation.

from nngeometry.metrics import FIM
from nngeometry.object import PMatDiag

FIM(
    model=torch_model,
    loader=dataloader,
    representation=PMatDiag,
    n_output=1,
    device="cpu",
)

But I get the following error:

Exception                                 Traceback (most recent call last)
Cell In[93], line 6
      3 from nngeometry.metrics import FIM
      4 from nngeometry.object import PMatDiag
----> 6 F_ekfac = FIM(
      7     model=torch_model,
      8     loader=dataloader,
      9     representation=PMatDiag,
     10     n_output=1,
     11     device="cpu",
     12 )

File ~/opt/miniconda3/envs/pythia/lib/python3.10/site-packages/nngeometry/metrics.py:147, in FIM(model, loader, representation, n_output, variant, device, function, layer_collection)
    144         return model(d[0].to(device))
    146 if layer_collection is None:
--> 147     layer_collection = LayerCollection.from_model(model)
    149 if variant == 'classif_logits':
    151     def function_fim(*d):

File ~/opt/miniconda3/envs/pythia/lib/python3.10/site-packages/nngeometry/layercollection.py:50, in LayerCollection.from_model(model, ignore_unsupported_layers)
     48     elif not ignore_unsupported_layers:
     49         if len(list(mod.children())) == 0 and len(list(mod.parameters())) > 0:
---> 50             raise Exception('I do not know what to do with layer ' + str(mod))
     52 return lc

Exception: I do not know what to do with layer Embedding(50304, 512)

It looks like this error has to do with the embedding layers: there are two of them, one that maps token ids from the vocabulary space (size 50304) to the latent space (size 512), and another at the end that maps back the other way.
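
If I read the traceback correctly, LayerCollection.from_model accepts an ignore_unsupported_layers flag and FIM accepts a layer_collection argument, so I suppose I could at least skip the embeddings, although that would leave their parameters out of the diagonal, which is not what I want:

from nngeometry.layercollection import LayerCollection

# build a LayerCollection that silently skips the unsupported embedding layers
lc = LayerCollection.from_model(torch_model, ignore_unsupported_layers=True)

# the resulting diagonal would then cover every layer except the embeddings
F_diag = FIM(
    model=torch_model,
    loader=dataloader,
    representation=PMatDiag,
    n_output=1,
    device="cpu",
    layer_collection=lc,
)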

What should I do to have the FIM diagonal of all model parameters?
Many thanks, and again, great package.

CarloNicolini · Jan 11 '24 21:01