Unofficial implementations of MLP-Mixer, gMLP, resMLP, Vision Permutator, S2MLP, S2MLPv2, RaftMLP, HireMLP, ConvMLP, SparseMLP, ConvMixer, AS-MLP, SwinMLP, RepMLPNet, WaveMLP, MorphMLP, DynaMixer, MS-MLP, and Sequencer2D in both Jittor and PyTorch, plus GFNet and CycleMLP in PyTorch only.

What's New

Are we ready for a new paradigm shift? A Survey on Visual Deep MLP (paper).

trunc_normal_ is now supported for Jittor (requires Jittor version >= 1.3.15; see ./models_jittor/utils/init.py)!

import math
import warnings
import jittor as jt

def trunc_normal_(var, mean=0., std=1., a=-2., b=2.):
    # type: (jt.jittor_core.Var, float, float, float, float) -> jt.jittor_core.Var
    r"""Fill ``var`` with values drawn from a truncated normal distribution.

    The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
    :math:`[a, b]` redrawn until they are within the bounds. The method used
    for generating the random values works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        var: an n-dimensional `jt.jittor_core.Var`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The Var produced by ``_no_grad_trunc_normal_``. Callers should use
        this return value.

    Examples:
        >>> w = jt.zeros((3, 5))
        >>> w = trunc_normal_(w, std=.02)
    """
    return _no_grad_trunc_normal_(var, mean, std, a, b)

def _no_grad_trunc_normal_(var, mean, std, a, b):
    """Fill ``var`` in place with truncated-normal samples via the inverse CDF.

    Adapted for Jittor from PyTorch's ``nn.init`` implementation. The method
    is based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        var: the Jittor Var to fill.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cutoff value.
        b: upper cutoff value.

    Returns:
        ``var`` itself, after its contents have been replaced.

    Warns:
        UserWarning: if ``mean`` is more than 2 std outside ``[a, b]``.
    """
    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # The CDF values of the cutoffs bound the uniform draw below.
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Draw uniformly from [2l-1, 2u-1]. erfinv maps that interval onto the
    # standard normal truncated to [(a-mean)/std, (b-mean)/std], after /sqrt(2).
    var.uniform_(low=2 * l - 1, high=2 * u - 1)
    out = var.erfinv()

    # Rescale to the requested mean / std.
    out = out.multiply(std * math.sqrt(2.))
    out = out.add(mean)

    # Clamp to guard against floating-point overshoot at the bounds.
    out = out.clamp(min_v=a, max_v=b)

    # Write the result back into the caller's Var. Before this, only the
    # uniform draw reached `var`. This matches the in-place `_` naming.
    var.assign(out)
    return var

The einops layers Rearrange and Reduce are now supported for Jittor! This makes it easier to convert Transformer-based and MLP-based models from PyTorch to Jittor.

  • from .einops_my.layers.jittor import Rearrange, Reduce (shown in ./models_jittor/raft_mlp.py, ./models_jittor/sparse_mlp.py)
  • rearrange, repeat, reduce, and parse_shape are also available for Jittor
import jittor as jt
import numpy as np
from einops_my.einops import rearrange, reduce, repeat, parse_shape

# parse_shape: read named axis sizes off a Var ('_' skips an axis).
x = jt.zeros([2, 3, 5, 7])
parse_shape(x, 'batch _ h w')  # {'batch': 2, 'h': 5, 'w': 7}

# rearrange: a list of 32 images is stacked along a new leading batch axis.
images = jt.array([np.random.randn(30, 40, 3) for _ in range(32)])
rearrange(images, 'b h w c -> b h w c').shape  # [32,30,40,3,]

# repeat: add a new axis c of size 3.
image = jt.array(np.random.randn(30, 40))
repeat(image, 'h w -> h w c', c=3).shape  # [30,40,3,]

# reduce: max over the leading axis t.
x = jt.array(np.random.randn(2, 32, 64))
y = reduce(x, 't b c -> b c', 'max')  # [32,64,]


  • Jittor and PyTorch implementation of gMLP

import jittor as jt
n,c,h,w = 2,10,4,5
x = jt.random((n,c,h,w))


def shift(x, h_offset, h_stride, h_cycle, w_offset, w_stride, w_cycle):
    """Spatially shift channel groups of ``x`` with ``Var.reindex``.

    Output element (i0, i1, i2, i3) reads input element
    (i0, i1, i2 + dh, i3 + dw), where
    ``dh = h_offset + h_stride * (i1 % h_cycle)`` and
    ``dw = w_offset + w_stride * (i1 % w_cycle)``.
    The shift therefore cycles with the channel index i1.
    Reads that land outside the input use reindex's ``overflow_value``
    (0 by default per the Jittor docs — confirm for your Jittor version).

    NOTE(review): the original index expressions were lost in extraction.
    This version was rebuilt from the parameter names and the example calls
    below.
    """
    # Output keeps the input's shape. Using x.shape avoids relying on module globals.
    out_shape = list(x.shape)
    # Parenthesise so negative offsets produce valid index expressions.
    dh = f"({h_offset})+({h_stride})*(i1%({h_cycle}))"
    dw = f"({w_offset})+({w_stride})*(i1%({w_cycle}))"
    return x.reindex(out_shape, [
        "i0",
        "i1",
        f"i2+{dh}",
        f"i3+{dw}",
    ])


# y = shift(x, -1, 1, 3, 0, 0, 1)
y = shift(x, 0, 0, 1, -1, 1, 3)


import jittor as jt
from models_jittor import gMLPForImageClassification as gMLP_jt
from models_jittor import ResMLPForImageClassification as ResMLP_jt
from models_jittor import MLPMixerForImageClassification as MLPMixer_jt
from models_jittor import ViP as ViP_jt
from models_jittor import S2MLPv2 as S2MLPv2_jt
from models_jittor import S2MLPv1_deep as S2MLPv1_deep_jt 
from models_jittor import ConvMixer as ConvMixer_jt
from models_jittor import convmlp_s as ConvMLP_s_jt 
from models_jittor import convmlp_l as ConvMLP_l_jt 
from models_jittor import convmlp_m as ConvMLP_m_jt 
from models_jittor import RaftMLP as RaftMLP_jt
from models_jittor import SparseMLP as SparseMLP_jt
from models_jittor import HireMLP as HireMLP_jt
from models_jittor import AS_MLP as AS_MLP_jt
from models_jittor import SwinMLP as SwinMLP_jt
from models_jittor import WaveMLP as WaveMLP_jt
from models_jittor import MS_MLP as MS_MLP_jt
from models_jittor import MorphMLP as MorphMLP_jt
from models_jittor import DynaMixer as DynaMixer_jt
from models_jittor import Sequencer2D as Sequencer2D_jt

# NOTE(review): the constructor arguments were lost when this snippet was
# extracted. The keyword values below are illustrative only; confirm them
# against MLPMixerForImageClassification in models_jittor.
model_jt = MLPMixer_jt(
    in_channels=3,
    image_size=224,
    patch_size=16,
    d_model=256,
    depth=12,
    num_classes=1000,
)

images = jt.randn(8, 3, 224, 224)
with jt.no_grad():
    output = model_jt(images)
print(output.shape)  # (8, 1000)


import torch
from models_pytorch import gMLPForImageClassification as gMLP_pt
from models_pytorch import ResMLPForImageClassification as ResMLP_pt
from models_pytorch import MLPMixerForImageClassification as MLPMixer_pt
from models_pytorch import ViP as ViP_pt
from models_pytorch import S2MLPv2 as S2MLPv2_pt 
from models_pytorch import S2MLPv1_deep as S2MLPv1_deep_pt
from models_pytorch import ConvMixer as ConvMixer_pt 
from models_pytorch import convmlp_s as ConvMLP_s_pt 
from models_pytorch import convmlp_l as ConvMLP_l_pt 
from models_pytorch import convmlp_m as ConvMLP_m_pt 
from models_pytorch import RaftMLP as RaftMLP_pt
from models_pytorch import SparseMLP as SparseMLP_pt
from models_pytorch import HireMLP as HireMLP_pt
from models_pytorch import GFNet as GFNet_pt
from models_pytorch import CycleMLP_B2 as CycleMLP_B2_pt
from models_pytorch import AS_MLP as AS_MLP_pt
from models_pytorch import SwinMLP as SwinMLP_pt
from models_pytorch import create_RepMLPNet_B224, create_RepMLPNet_B256
from models_pytorch import WaveMLP as WaveMLP_pt
from models_pytorch import MS_MLP as MS_MLP_pt
from models_pytorch import MorphMLP as MorphMLP_pt
from models_pytorch import DynaMixer as DynaMixer_pt
from models_pytorch import Sequencer2D as Sequencer2D_pt

model_pt = ViP_pt(
    segments=16,
    weighted=True,
)

images = torch.randn(8, 3, 224, 224)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    output = model_pt(images)
print(output.shape)  # (8, 1000)

############################## Non-square images and patch sizes #########################

# Non-square input (224 x 112) with a matching non-square patch (16 x 8).
model_jt = ViP_jt(
    image_size=(224, 112),
    patch_size=(16, 8),
    segments=16,
    weighted=True,
)
images = jt.randn(8, 3, 224, 112)
with jt.no_grad():
    output = model_jt(images)
print(output.shape)  # (8, 1000)

############################## 2 Stages S2MLPv2 #########################
# Two stages: each list-valued argument has one entry per stage.
model_pt = S2MLPv2_pt(
    in_channels=3,
    image_size=(224, 224),
    patch_size=[(7, 7), (2, 2)],
    d_model=[192, 384],
    depth=[4, 14],
    num_classes=1000,
    expansion_factor=[3, 3],
)

############################## ConvMLP With Pretrain Params #########################
model_jt = ConvMLP_s_jt(pretrained=True, num_classes=1000)  # ConvMLP-S, pretrained=True, 1000-class head

############################## RaftMLP #########################
# Each dict in `layers` configures one stage; this example uses a single stage.
model_jt = RaftMLP_jt(
    layers=[
        {"depth": 12,
         "dim": 768,
         "patch_size": 16,
         "raft_size": 4},
    ],
    gap=True,
)

############################## SparseMLP #########################
model_pt = SparseMLP_pt(
    expansion_factor=2,
    patcher_norm=True,
)

############################## HireMLP #########################

# Per-stage lists (d_model, h, w, cross_region_step) have four entries each.
model_pt = HireMLP_pt(
    d_model=[64, 128, 320, 512],
    h=[4, 3, 3, 2],
    w=[4, 3, 3, 2],
    cross_region_step=[2, 2, 1, 1],
    cross_region_interval=2,
    expansion_factor=2,
    patcher_norm=True,
    padding_type='circular',
)

############################## GFNet #########################
model_pt = GFNet_pt()  # default constructor arguments

############################## CycleMLP #########################
model_pt = CycleMLP_B2_pt()  # the B2 variant, default arguments

############################## AS-MLP #########################
model_pt = AS_MLP_pt()  # default constructor arguments

############################## WaveMLP #########################
model_pt = WaveMLP_pt('M')  # variant selected by the string 'M'

############################## MS-MLP #########################
model_pt = MS_MLP_pt()  # default constructor arguments

############################## MorphMLP #########################
model_pt = MorphMLP_pt('B')  # variant selected by the string 'B'

############################## DynaMixer #########################
model_pt = DynaMixer_pt('M')  # variant selected by the string 'M'

############################## Sequencer2D #########################
model_pt = Sequencer2D_pt('M')  # variant selected by the string 'M'


