
Bottleneck Configs do not work with `ln_before = True` and `init_weights = "mam_adapter"`


If we specify a bottleneck config with layer normalization before the adapter bottleneck (`ln_before = True`) and set the adapter's weight initialization to `init_weights = "mam_adapter"`, adding the adapter triggers an error.

Reproducible Code

from transformers import RobertaConfig
from adapters import AutoAdapterModel

config = RobertaConfig.from_pretrained(
    "roberta-base",
    num_labels=2,
)
model = AutoAdapterModel.from_pretrained(
    "roberta-base",
    config=config,
)

from adapters import BnConfig

config = BnConfig(
    mh_adapter=True,
    output_adapter=False,
    ln_before=True,
    reduction_factor=16,
    non_linearity="relu",
    init_weights="mam_adapter",
)
model.add_adapter("bottleneck_adapter", config=config)

This happens because, with `ln_before = True`, the first module of `seq_list` (and hence of `self.adapter_down`) is a normalization layer rather than the down-projection. When the weights are then initialized with `init_weights = "mam_adapter"`, Kaiming uniform initialization is applied to `self.adapter_down[0]` on the assumption that it is $W_{down}$, but it is actually the `nn.LayerNorm` layer, whose weight is 1-dimensional, so `nn.init.kaiming_uniform_` raises an error.

elif config["init_weights"] == "mam_adapter":
    with torch.no_grad():
        nn.init.kaiming_uniform_(self.adapter_down[0].weight, a=math.sqrt(5))
        nn.init.zeros_(self.adapter_up.weight)
        nn.init.zeros_(self.adapter_down[0].bias)
        nn.init.zeros_(self.adapter_up.bias)
        if self.use_gating:
            self.gate.apply(self.init_bert_weights)
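
As a standalone illustration (independent of the adapters library, using a hypothetical hidden size of 768 and bottleneck size of 48), `nn.init.kaiming_uniform_` works on the 2-D weight of a `nn.Linear` down-projection but raises a `ValueError` on the 1-D weight of a `nn.LayerNorm`, which is exactly what `adapter_down[0]` is when `ln_before = True`:

import math
import torch.nn as nn

hidden_size, bottleneck_size = 768, 48  # hypothetical sizes for illustration

down_proj = nn.Linear(hidden_size, bottleneck_size)  # what "mam_adapter" init expects at index 0
layer_norm = nn.LayerNorm(hidden_size)               # what adapter_down[0] actually is with ln_before=True

nn.init.kaiming_uniform_(down_proj.weight, a=math.sqrt(5))  # works: 2-D weight

try:
    nn.init.kaiming_uniform_(layer_norm.weight, a=math.sqrt(5))  # fails: 1-D weight
except ValueError as err:
    print(err)  # fan in / fan out cannot be computed for a tensor with fewer than 2 dimensions

Until this is handled in the library, a workaround is to avoid combining the two options, e.g. by setting `ln_before = False` or by using a different `init_weights` value with this configuration.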

julian-fong, Oct 19 '24