
cosFormer-PyTorch

An unofficial PyTorch implementation of the model proposed in the paper cosFormer: Rethinking Softmax In Attention (Submitted to ICLR 2022).

Image stolen from the paper.

Table of contents

  • ➤ Paper Summary
  • ➤ Installation
  • ➤ Usage
  • ➤ Citations

Paper Summary

As many others, this paper builds on recent work on linear attention, which is calculated as $\phi(Q)\left(\phi(K)^\top V\right)$ instead of $\mathrm{softmax}\left(QK^\top\right)V$, where $\phi$ is a kernel function. This reduces the complexity from $\mathcal{O}(N^2)$ to $\mathcal{O}(N)$. The authors propose to extend this mechanism by including relative distance information in the Q, K product as $\phi(Q_i)\,\phi(K_j)^\top \cos\left(\frac{\pi}{2}\cdot\frac{i-j}{M}\right)$. After expanding the trigonometric identity $\cos(x-y)=\cos x\cos y+\sin x\sin y$, the full (non-causal) equation becomes:

$$\mathcal{O}_i = \frac{\sum_{j} Q_i^{\cos}\left((K_j^{\cos})^\top V_j\right) + \sum_{j} Q_i^{\sin}\left((K_j^{\sin})^\top V_j\right)}{\sum_{j} Q_i^{\cos}\,(K_j^{\cos})^\top + \sum_{j} Q_i^{\sin}\,(K_j^{\sin})^\top}$$

where $Q_i^{\cos} = \phi(Q_i)\cos\left(\frac{\pi i}{2M}\right)$, $Q_i^{\sin} = \phi(Q_i)\sin\left(\frac{\pi i}{2M}\right)$, and $K_j^{\cos}$, $K_j^{\sin}$ are defined analogously.

As the author of this repo possesses neither the time nor the ability, only the non-causal version of this approach is implemented.
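
For intuition, here is a minimal sketch of the non-causal computation with a ReLU kernel, written directly from the formula above. It is purely illustrative and is not the code used in this repo; the function name, the tensor layout and the choice M = sequence length are assumptions.

import math
import torch
import torch.nn.functional as F

def cosformer_attention_noncausal(Q, K, V, eps=1e-5):
    # Q, K, V: [batch, heads, seq_len, d_head] (assumed layout).
    seq_len = Q.shape[2]
    M = seq_len  # length scale of the cosine reweighting

    # Kernel feature map phi; here phi(x) = ReLU(x), as in the paper.
    Q, K = F.relu(Q), F.relu(K)

    # Position-dependent weights cos(pi*i / (2M)) and sin(pi*i / (2M)).
    idx = torch.arange(seq_len, dtype=Q.dtype, device=Q.device)
    cos_w = torch.cos(math.pi * idx / (2 * M))[None, None, :, None]
    sin_w = torch.sin(math.pi * idx / (2 * M))[None, None, :, None]

    Q_cos, Q_sin = Q * cos_w, Q * sin_w
    K_cos, K_sin = K * cos_w, K * sin_w

    # Numerator: Q^cos ((K^cos)^T V) + Q^sin ((K^sin)^T V).
    # Computing (K^T V) first is what makes the cost linear in seq_len.
    num = Q_cos @ (K_cos.transpose(-2, -1) @ V) + Q_sin @ (K_sin.transpose(-2, -1) @ V)

    # Denominator: Q^cos sum_j K^cos_j + Q^sin sum_j K^sin_j.
    den = Q_cos @ K_cos.sum(dim=2).unsqueeze(-1) + Q_sin @ K_sin.sum(dim=2).unsqueeze(-1)

    return num / (den + eps)

x = torch.randn(2, 8, 100, 64)
out = cosformer_attention_noncausal(x, x, x)  # -> [2, 8, 100, 64]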

Installation

$ git clone https://github.com/davidsvy/cosformer-pytorch
$ cd cosformer-pytorch
$ pip install -r requirements.txt

Usage

from models.kernel_transformer import Kernel_transformer
import torch

model = Kernel_transformer(
    # Linear attention args:
    use_cos=True,         # Whether to use the cosine reweighting mechanism proposed in the paper.
    kernel='relu',        # Kernel that approximates softmax. Available options are 'relu' and 'elu'.
    denom_eps=1e-5,       # Added to the denominator of linear attention for numerical stability.
    # If use_cos=True & kernel='relu' the model is equivalent to https://openreview.net/pdf?id=Bl8CQrx2Up4
    # If use_cos=False & kernel='elu' the model is equivalent to https://arxiv.org/pdf/2006.16236.pdf
    # Vanilla transformer args:
    d_model=512,
    n_heads=8, 
    n_layers=6,
    n_emb=20000, 
    ffn_ratio=4, 
    rezero=True,          # If True, use the ReZero architecture from https://arxiv.org/pdf/2003.04887.pdf, else the Pre-LN architecture from https://arxiv.org/pdf/2002.04745.pdf
    ln_eps=1e-5, 
    bias=False, 
    dropout=0.2, 
    max_len=1024, 
    xavier=True
)

input_ids = torch.randint(0, 20000, [4, 100])
lengths = torch.randint(1, 100, [4])
attention_mask = torch.arange(100)[None, :] < lengths[:, None]

output = model(
    input_ids=input_ids,
    lengths=lengths,
    attention_mask=attention_mask,
)
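
Here, attention_mask is a boolean padding mask built from lengths: entry [b, t] is True exactly when t < lengths[b], so padded positions are excluded from attention. A quick sanity check (illustrative only):

# Each row contains lengths[b] leading True values followed by False padding.
assert attention_mask.shape == (4, 100)
assert torch.equal(attention_mask.sum(dim=1), lengths)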

Citations

@inproceedings{
    anonymous2022cosformer,
    title={cosFormer: Rethinking Softmax In Attention},
    author={Anonymous},
    booktitle={Submitted to The Tenth International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=Bl8CQrx2Up4},
    note={under review}
}