

PyTorch Implementation of Low Rank Factorization for Compact Multi-Head Self-Attention (LAMA)

This is a PyTorch implementation of the Low Rank Factorization for Compact Multi-Head Attention (LAMA) mechanism and the corresponding pooler introduced in the paper: "Low Rank Factorization for Compact Multi-Head Self-Attention".

Figure 1 from Low Rank Factorization for Compact Multi-Head Self-Attention.

Note: I am not one of the authors on the paper.

Usage

The only dependency is PyTorch; installation instructions can be found on the PyTorch website.

LAMA

import torch
from modules.lama import LAMA

num_heads = 8      # Number of attention heads
input_dim = 768    # Dimension of each token's hidden representation
batch_size = 16    # Number of sentences/documents in the mini-batch
max_seq_len = 100  # Maximum length of the input sequence

# Create a random input sequence
inputs = torch.randn(batch_size, max_seq_len, input_dim)  
# Optionally, you can provide a mask over timesteps (e.g., for padding tokens)
# Size: (batch_size, max_seq_len), 0 where timesteps should be masked and 1 otherwise
mask = torch.ones(batch_size, max_seq_len)
mask[:, -1] = 0

# Initialize the attention mechanism
lama = LAMA(num_heads, input_dim)
output = lama(inputs, mask)

assert output.size() == (batch_size, num_heads, max_seq_len)
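
The tensor returned by LAMA holds one attention distribution over timesteps per head. If you want the pooled representations themselves, a batched matrix multiply against the inputs gives one weighted average of the token vectors per head. The short sketch below is not part of the library (it is presumably close to what LAMAEncoder computes for you) and reuses the objects created above.

# Apply the per-head attention distributions to the token representations.
# `output` has size (batch_size, num_heads, max_seq_len) and `inputs` has size
# (batch_size, max_seq_len, input_dim), so bmm yields one pooled vector per head.
structured_embedding = torch.bmm(output, inputs)

assert structured_embedding.size() == (batch_size, num_heads, input_dim)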

LAMAEncoder

import torch
from modules.lama_encoder import LAMAEncoder

num_heads = 8      # Number of attention heads
input_dim = 768    # Dimension of each token's hidden representation
batch_size = 16    # Number of sentences/documents in the mini-batch
max_seq_len = 100  # Maximum length of the input sequence

# Create a random input sequence
inputs = torch.randn(batch_size, max_seq_len, input_dim)  
# Optionally, you can provide a mask over timesteps (e.g., for padding tokens)
# Size: (batch_size, max_seq_len), 0 where timesteps should be masked and 1 otherwise
mask = torch.ones(batch_size, max_seq_len)
mask[:, -1] = 0

# Initialize the encoder
lama_encoder = LAMAEncoder(num_heads, input_dim)
output = lama_encoder(inputs, mask)

assert output.size() == (batch_size, num_heads, input_dim)

# If output_dim is given (it defaults to None), the per-head "structured sentence embedding"
# is flattened by concatenation and projected by a linear layer into a vector of this size
lama_encoder = LAMAEncoder(num_heads, input_dim, output_dim=128)
output = lama_encoder(inputs, mask)

assert output.size() == (batch_size, 128)
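
Because the projected output is a fixed-size vector per example, it can be fed straight into any downstream head. Below is a minimal sketch (not part of the library) that puts a hypothetical linear classifier on top of the 128-dimensional embedding from above.

# Hypothetical classification head on top of the pooled sentence embedding.
num_classes = 5  # assumed number of target classes, for illustration only
classifier = torch.nn.Linear(128, num_classes)
logits = classifier(output)

assert logits.size() == (batch_size, num_classes)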