Transformer in Transformer in JAX/Flax

This repository implements Transformer in Transformer (TNT) in JAX/Flax, an image classification architecture that pairs pixel-level attention within each patch with patch-level attention across patches (a rough sketch of the two-level design follows below). It is heavily inspired by both lucidrains' PyTorch implementation and Google Brain's Vision Transformer repo.

AI Coffee Break with Letitia
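
To give a sense of how the two attention levels interact inside one block, here is a minimal, self-contained Flax sketch. It is illustrative only and does not reflect this package's internal modules; TNTBlockSketch, pixel_tokens and patch_tokens are hypothetical names.

import jax
import jax.numpy as jnp
from flax import linen as nn

class TNTBlockSketch(nn.Module):
    # Illustrative TNT-style block: inner attention over pixel tokens within
    # each patch, outer attention over patch tokens across the image.
    inner_heads: int = 4
    outer_heads: int = 10
    outer_dim: int = 640

    @nn.compact
    def __call__(self, pixel_tokens, patch_tokens):
        # pixel_tokens: (batch, patches, pixels_per_patch, inner_dim)
        # patch_tokens: (batch, patches, outer_dim)
        b, n, p, d = pixel_tokens.shape

        # inner transformer: self-attention among the pixel tokens of a single patch
        flat = pixel_tokens.reshape(b * n, p, d)
        inner = nn.MultiHeadDotProductAttention(num_heads = self.inner_heads)(flat, flat)
        pixel_tokens = pixel_tokens + inner.reshape(b, n, p, d)

        # fold the refined pixel tokens back into their patch embedding
        patch_tokens = patch_tokens + nn.Dense(self.outer_dim)(pixel_tokens.reshape(b, n, p * d))

        # outer transformer: self-attention among the patch tokens
        outer = nn.MultiHeadDotProductAttention(num_heads = self.outer_heads)(patch_tokens, patch_tokens)
        return pixel_tokens, patch_tokens + outer

# example shapes matching the TNT-B config below: 196 patches, 16 pixel tokens per patch
block = TNTBlockSketch()
pixels = jnp.ones((1, 196, 16, 40))
patches = jnp.ones((1, 196, 640))
variables = block.init(jax.random.PRNGKey(0), pixels, patches)
pixels, patches = block.apply(variables, pixels, patches) # (1, 196, 16, 40), (1, 196, 640)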

Install

$ pip install transformer-in-transformer-flax

Usage

from jax import random
from jax import numpy as jnp
from transformer_in_transformer_flax import TransformerInTransformer, TNTConfig

# example configuration for TNT-B
# inner_* options configure the pixel-level transformer, outer_* the patch-level one
config = TNTConfig(
    num_classes = 1000,
    depth = 12,
    image_size = 224,
    patch_size = 16,
    transformed_patch_size = 4,
    inner_dim = 40,
    inner_heads = 4,
    inner_dim_head = 64,
    inner_r = 4,
    outer_dim = 640,
    outer_heads = 10,
    outer_dim_head = 64,
    outer_r = 4
)

rng = random.PRNGKey(seed=0)
model = TransformerInTransformer(config=config)

# initialize parameters with a dummy input of the configured image size
params = model.init(rng, jnp.ones((1, 224, 224, 3), dtype=config.dtype))

# run the forward pass on a batch of images
img = random.uniform(rng, (2, 224, 224, 3))
logits = model.apply(params, img) # (2, 1000)
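
For repeated inference it is natural to jit-compile the forward pass; a minimal follow-up sketch reusing the model, params and img defined above:

import jax

# compile the forward pass once, then reuse it across batches
predict = jax.jit(lambda params, img: model.apply(params, img))
logits = predict(params, img) # (2, 1000)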

Citations

@misc{han2021transformer,
    title   = {Transformer in Transformer}, 
    author  = {Kai Han and An Xiao and Enhua Wu and Jianyuan Guo and Chunjing Xu and Yunhe Wang},
    year    = {2021},
    eprint  = {2103.00112},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}