EinMix: universal toolkit for advanced MLP architectures¶
Recent progress in MLP-based architectures demonstrated that very specific MLPs can compete with convnets and transformers (and even outperform them).
EinMix allows writing such architectures in a more uniform and readable way.
EinMix — building block of MLPs¶
from einops.layers.torch import EinMix as Mix
The logic of EinMix is very close to that of einsum.
If you're not familiar with einsum, follow these guides first:
- https://rockt.github.io/2018/04/30/einsum
- https://towardsdatascience.com/einsum-an-underestimated-function-99ca96e2942e
- https://theaisummer.com/einsum-attention/
Einsum uniformly describes a number of operations; EinMix, however, is defined slightly differently.
Here is a linear layer, a common block in sequence modelling (e.g. in NLP/speech), written with einsum
weight = <...create tensor...>
result = torch.einsum('tbc,cd->tbd', embeddings, weight)
The EinMix counterpart is:
mix_channels = Mix('t b c -> t b c_out', weight_shape='c c_out', ...)
result = mix_channels(embeddings)
The main differences compared to plain einsum are:
- the layer takes care of the weight initialization & management hassle
- the weight is not part of the pattern; its shape is given separately via weight_shape
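For illustration, the remaining (elided) arguments are just axis lengths passed as keyword arguments. A minimal sketch; n_channels and n_channels_out are hypothetical names:

# hypothetical axis sizes, just to make the call concrete
n_channels, n_channels_out = 512, 1024
mix_channels = Mix('t b c -> t b c_out', weight_shape='c c_out', c=n_channels, c_out=n_channels_out)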
We'll discuss other differences a bit later; for now, let's implement ResMLP.
# let's start
import torch
from torch import nn

# No norm layer
class Affine(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return self.alpha * x + self.beta

class Mlp(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

class ResMLP_Blocks(nn.Module):
    def __init__(self, nb_patches, dim, layerscale_init):
        super().__init__()
        self.affine_1 = Affine(dim)
        self.affine_2 = Affine(dim)
        self.linear_patches = nn.Linear(nb_patches, nb_patches)  # Linear layer on patches
        self.mlp_channels = Mlp(dim)  # MLP on channels
        self.layerscale_1 = nn.Parameter(layerscale_init * torch.ones((dim)))  # LayerScale
        self.layerscale_2 = nn.Parameter(layerscale_init * torch.ones((dim)))  # parameters

    def forward(self, x):
        res_1 = self.linear_patches(self.affine_1(x).transpose(1, 2)).transpose(1, 2)
        x = x + self.layerscale_1 * res_1
        res_2 = self.mlp_channels(self.affine_2(x))
        x = x + self.layerscale_2 * res_2
        return x
ResMLP — rewritten¶
The code below is the result of the first rewriting: the combination [transpose -> linear -> transpose back] got nicely packed into a single EinMix (mix_patches):

Mix('b t c -> b t0 c', weight_shape='t t0', bias_shape='t0', t=nb_patches, t0=nb_patches)

- the pattern 'b t c -> b t0 c' tells that b and c are unperturbed, while tokens t -> t0 were mixed
- explicit parameter shapes are also quite insightful
In the new implementation, the affine layer is also handled by EinMix:

Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim)

- from the pattern you can see that there is no mixing at all, only multiplication and shift
- multiplication and shift are defined by the weight and bias, and those depend only on the channel
- thus the affine transform is per-channel
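As a quick check of the per-channel claim, here is a minimal sketch (with hypothetical sizes) showing that this EinMix layer reproduces the Affine module above once its parameters are copied over:

# EinMix exposes .weight and .bias, so we can copy Affine's parameters into it
affine_ref = Affine(16)
affine_mix = Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=16)
with torch.no_grad():
    affine_mix.weight.copy_(affine_ref.alpha)
    affine_mix.bias.copy_(affine_ref.beta)
x = torch.randn(2, 8, 16)
assert torch.allclose(affine_ref(x), affine_mix(x), atol=1e-6)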
The linear layer is also handled by EinMix; the only difference compared to the affine layer is the absence of a bias.
We specified that the input is 3d and the axis order is b t c, not t b c; this is not written explicitly in the original code.
The only step back we had to take is a change of the initialization scheme for the EinMix affine and linear layers.
def Mlp(dim):
    return nn.Sequential(
        nn.Linear(dim, 4 * dim),
        nn.GELU(),
        nn.Linear(4 * dim, dim),
    )

def init(Mix_layer, scale=1.):
    Mix_layer.weight.data[:] = scale
    if Mix_layer.bias is not None:
        Mix_layer.bias.data[:] = 0
    return Mix_layer

class ResMLP_Blocks2(nn.Module):
    def __init__(self, nb_patches, dim, layerscale_init):
        super().__init__()
        self.affine1 = init(Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim))
        self.affine2 = init(Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim))
        self.mix_patches = Mix('b t c -> b t0 c', weight_shape='t t0', bias_shape='t0', t=nb_patches, t0=nb_patches)
        self.mlp_channels = Mlp(dim)
        self.linear1 = init(Mix('b t c -> b t c', weight_shape='c', c=dim), scale=layerscale_init)
        self.linear2 = init(Mix('b t c -> b t c', weight_shape='c', c=dim), scale=layerscale_init)

    def forward(self, x):
        res1 = self.mix_patches(self.affine1(x))
        x = x + self.linear1(res1)
        res2 = self.mlp_channels(self.affine2(x))
        x = x + self.linear2(res2)
        return x
ResMLP — rewritten more¶
Since here in einops-land we care about code being easy to follow, let's make one more transformation.
We group the layers of each branch, and now the order of operations matches the order in which they are written in the code.
Could we go further? Actually, yes: nn.Linear layers can also be replaced by EinMix. However, they are quite organic here, since the first and last operations in branch_channels already show the components. The brevity of nn.Linear is beneficial when the context specifies the tensor shapes.
Other interesting observations:
- it is hard to notice in the original code that the first nn.Linear is preceded by a linear layer (thus the latter is redundant or can be fused into the former)
- it is also hard to notice in the original code that the second nn.Linear is followed by an affine layer (thus the latter is again redundant)
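As a quick sanity check of the first observation, a per-channel linear layer in front of nn.Linear can indeed be folded into the Linear's weight. A minimal sketch with hypothetical sizes:

dim = 8
scale = torch.rand(dim)                    # per-channel multiplier, as in the layerscale Mix above
fc = nn.Linear(dim, 4 * dim)

fused = nn.Linear(dim, 4 * dim)            # absorbs the per-channel scale into the weight
with torch.no_grad():
    fused.weight.copy_(fc.weight * scale)  # scale the input columns of the weight
    fused.bias.copy_(fc.bias)

x = torch.randn(2, 3, dim)
assert torch.allclose(fc(x * scale), fused(x), atol=1e-6)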
Take time to reorganize your code. This may be quite insightful.
def init(layer: Mix, scale=1.):
    layer.weight.data[:] = scale
    if layer.bias is not None:
        layer.bias.data[:] = 0
    return layer

class ResMLP_Blocks3(nn.Module):
    def __init__(self, nb_patches, dim, layerscale_init):
        super().__init__()
        self.branch_patches = nn.Sequential(
            init(Mix('b t c -> b t c', weight_shape='c', c=dim), scale=layerscale_init),
            Mix('b t c -> b t0 c', weight_shape='t t0', bias_shape='t0', t=nb_patches, t0=nb_patches),
            init(Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim)),
        )
        self.branch_channels = nn.Sequential(
            init(Mix('b t c -> b t c', weight_shape='c', c=dim), scale=layerscale_init),
            nn.Linear(dim, 4 * dim),
            nn.GELU(),
            nn.Linear(4 * dim, dim),
            init(Mix('b t c -> b t c', weight_shape='c', bias_shape='c', c=dim)),
        )

    def forward(self, x):
        x = x + self.branch_patches(x)
        x = x + self.branch_channels(x)
        return x
ResMLP — performance¶
There is some fear of using einsum because historically it lagged in performance.
Below we run a test and verify that performance didn't change after the transition to EinMix.
x = torch.zeros([32, 128, 128])
for layer in [
    ResMLP_Blocks(128, dim=128, layerscale_init=1.),
    ResMLP_Blocks2(128, dim=128, layerscale_init=1.),
    ResMLP_Blocks3(128, dim=128, layerscale_init=1.),
    # scripted versions
    torch.jit.script(ResMLP_Blocks(128, dim=128, layerscale_init=1.)),
    torch.jit.script(ResMLP_Blocks2(128, dim=128, layerscale_init=1.)),
    torch.jit.script(ResMLP_Blocks3(128, dim=128, layerscale_init=1.)),
]:
    %timeit -n 10 y = layer(x)
28.1 ms ± 1.61 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
26.3 ms ± 620 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
25.9 ms ± 706 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
26.8 ms ± 2.99 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
25.9 ms ± 794 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
25.6 ms ± 723 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
TokenMixer from MLPMixer — original code¶
Let's now delve into MLPMixer. We start from the PyTorch implementation by Jake Tae.
We'll focus on two components of MLPMixer that don't exist in convnets. The first component is the TokenMixer:
from torch.nn import functional as F

class MLP(nn.Module):
    def __init__(self, num_features, expansion_factor, dropout):
        super().__init__()
        num_hidden = num_features * expansion_factor
        self.fc1 = nn.Linear(num_features, num_hidden)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(num_hidden, num_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout1(F.gelu(self.fc1(x)))
        x = self.dropout2(self.fc2(x))
        return x

class TokenMixer(nn.Module):
    def __init__(self, num_features, num_patches, expansion_factor, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(num_features)
        self.mlp = MLP(num_patches, expansion_factor, dropout)

    def forward(self, x):
        # x.shape == (batch_size, num_patches, num_features)
        residual = x
        x = self.norm(x)
        x = x.transpose(1, 2)
        # x.shape == (batch_size, num_features, num_patches)
        x = self.mlp(x)
        x = x.transpose(1, 2)
        # x.shape == (batch_size, num_patches, num_features)
        out = x + residual
        return out
TokenMixer from MLPMixer — reimplemented¶
We can significantly reduce the amount of code by using EinMix.

- The main caveat addressed by the original code is that nn.Linear mixes only the last axis. EinMix can mix any axis.
- A sequential structure is always preferred, as it is easier to follow.
- Intentionally, there is no residual connection in TokenMixer, because honestly it's not the Mixer's job and should be done by the caller.
def TokenMixer(num_features: int, n_patches: int, expansion_factor: int, dropout: float):
    n_hidden = n_patches * expansion_factor
    return nn.Sequential(
        nn.LayerNorm(num_features),
        Mix('b hw c -> b hid c', weight_shape='hw hid', bias_shape='hid', hw=n_patches, hid=n_hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        Mix('b hid c -> b hw c', weight_shape='hid hw', bias_shape='hw', hw=n_patches, hid=n_hidden),
        nn.Dropout(dropout),
    )
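A quick shape check under assumed sizes (196 patches, 128 features):

mixer = TokenMixer(num_features=128, n_patches=196, expansion_factor=4, dropout=0.)
x = torch.zeros([2, 196, 128])
assert mixer(x).shape == (2, 196, 128)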
You may also like the independent implementation of MLPMixer from Phil Wang.
Phil solves the issue by repurposing nn.Conv1d to mix over the second dimension. Hacky, but it does the job.
MLPMixer's patch embeddings — original¶
Second interesting part of MLPMixer is derived from vision transformers.
In the very beginning, an image is split into patches, and each patch is linearly projected into an embedding.
I've taken the part of Jake's code responsible for embedding patches:
def check_sizes(image_size, patch_size):
    sqrt_num_patches, remainder = divmod(image_size, patch_size)
    assert remainder == 0, "`image_size` must be divisible by `patch_size`"
    num_patches = sqrt_num_patches ** 2
    return num_patches

class Patcher(nn.Module):
    def __init__(
        self,
        image_size=256,
        patch_size=16,
        in_channels=3,
        num_features=128,
    ):
        _num_patches = check_sizes(image_size, patch_size)
        super().__init__()
        # per-patch fully-connected is equivalent to strided conv2d
        self.patcher = nn.Conv2d(
            in_channels, num_features, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        patches = self.patcher(x)
        batch_size, num_features, _, _ = patches.shape
        patches = patches.permute(0, 2, 3, 1)
        patches = patches.view(batch_size, -1, num_features)
        return patches
MLPMixer's patch embeddings — reimplemented¶
EinMix does this in a single operation. This may require some training at first to understand.

Let's go step-by-step:

- b c_in (h hp) (w wp) -> : the 4-dimensional input tensor (BCHW-ordered) is split into patches of shape hp x wp
- weight_shape='c_in hp wp c': axes c_in, hp and wp are all absent in the output; the three-dimensional patch tensor was mixed to produce a vector of length c
- -> b (h w) c: the output is 3-dimensional; all patches were reorganized from an h x w grid into a one-dimensional sequence of vectors

We don't need to provide image_size beforehand; the new implementation handles images of different sizes as long as they can be divided into patches.
def patcher(patch_size=16, in_channels=3, num_features=128):
    return Mix('b c_in (h hp) (w wp) -> b (h w) c', weight_shape='c_in hp wp c', bias_shape='c',
               c=num_features, hp=patch_size, wp=patch_size, c_in=in_channels)
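A quick shape check under assumed sizes (224x224 RGB images, 16x16 patches):

embed = patcher(patch_size=16, in_channels=3, num_features=128)
images = torch.zeros([2, 3, 224, 224])
assert embed(images).shape == (2, 14 * 14, 128)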
Vision Permutator¶
As a third example, we consider pytorch-like code from the ViP paper.
Vision Permutator is only slightly more nuanced than the previous models, because
- it operates on spatial dimensions separately, while MLPMixer and its friends just pack all spatial info into one axis.
- it splits channels into groups called 'segments'
The paper provides pseudo-code, so I reworked it into a complete module with minimal changes. Enjoy:
class WeightedPermuteMLP(nn.Module):
    def __init__(self, H, W, C, S):
        super().__init__()
        self.proj_h = nn.Linear(H * S, H * S)
        self.proj_w = nn.Linear(W * S, W * S)
        self.proj_c = nn.Linear(C, C)
        self.proj = nn.Linear(C, C)
        self.S = S

    def forward(self, x):
        B, H, W, C = x.shape
        S = self.S
        N = C // S
        x_h = x.reshape(B, H, W, N, S).permute(0, 3, 2, 1, 4).reshape(B, N, W, H * S)
        x_h = self.proj_h(x_h).reshape(B, N, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)

        x_w = x.reshape(B, H, W, N, S).permute(0, 1, 3, 2, 4).reshape(B, H, N, W * S)
        x_w = self.proj_w(x_w).reshape(B, H, N, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)

        x_c = self.proj_c(x)

        x = x_h + x_w + x_c
        x = self.proj(x)
        return x
That didn't look readable, right?
This code is also very inflexible: code in the paper did not support batch dimension, and multiple changes were necessary to allow batch processing.
This process is fragile and can easily result in virtually uncatchable bugs.
Now the good news: each of these long method chains can be replaced with a single EinMix layer:
class WeightedPermuteMLP_new(nn.Module):
    def __init__(self, H, W, C, seg_len):
        super().__init__()
        assert C % seg_len == 0, f"can't divide {C} into segments of length {seg_len}"
        self.mlp_c = Mix('b h w c -> b h w c0', weight_shape='c c0', bias_shape='c0', c=C, c0=C)
        self.mlp_h = Mix('b h w (n c) -> b h0 w (n c0)', weight_shape='h c h0 c0', bias_shape='h0 c0',
                         h=H, h0=H, c=seg_len, c0=seg_len)
        self.mlp_w = Mix('b h w (n c) -> b h w0 (n c0)', weight_shape='w c w0 c0', bias_shape='w0 c0',
                         w=W, w0=W, c=seg_len, c0=seg_len)
        self.proj = nn.Linear(C, C)

    def forward(self, x):
        x = self.mlp_c(x) + self.mlp_h(x) + self.mlp_w(x)
        return self.proj(x)
Great, now let's confirm that performance did not deteriorate.
x = torch.zeros([32, 32, 32, 128])
for layer in [
    WeightedPermuteMLP(H=32, W=32, C=128, S=4),
    WeightedPermuteMLP_new(H=32, W=32, C=128, seg_len=4),
    # scripted versions
    torch.jit.script(WeightedPermuteMLP(H=32, W=32, C=128, S=4)),
    torch.jit.script(WeightedPermuteMLP_new(H=32, W=32, C=128, seg_len=4)),
]:
    %timeit -n 10 y = layer(x)
90.5 ms ± 1.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
91.5 ms ± 616 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
91.8 ms ± 626 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
87.4 ms ± 3.59 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Final remarks¶
EinMix has incredible potential: it helps with MLPs that don't fit into the limited 'mix all in the last axis' paradigm.

However, existing research is ... very limited; it does not cover the real possibilities of densely connected architectures.
Most of its systematic novelty is "mix along spatial axes too".
But EinMix provides an astonishing number of other possibilities!
Groups of mixers¶
You can find two settings compared in the MLPMixer paper (Supplementary A1):

'b hw c -> b hw_out c', weight_shape='hw hw_out'

and

'b hw c -> b hw_out c', weight_shape='c hw hw_out'

While the latter makes more sense (why should mixing work similarly for all channels?), the former performs better.
So one more question is reasonable: what if channels are split into groups, and mixing is defined for each group?

'b hw (group c) -> b hw_out (group c)', weight_shape='group hw hw_out'

Implementing such a setting without einops is considerably harder.
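A minimal sketch of such a grouped mixer; all sizes below are hypothetical:

# 4 groups of 32 channels each, tokens mixed independently within every group
grouped_mix = Mix('b hw (group c) -> b hw_out (group c)', weight_shape='group hw hw_out',
                  group=4, c=32, hw=196, hw_out=196)
x = torch.zeros([2, 196, 4 * 32])
assert grouped_mix(x).shape == (2, 196, 4 * 32)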
Mixing within patch on a grid¶
What if you make the mixing 'local' in space? Completely doable:

'b c (h hI) (w wI) -> b c (h hO) (w wO)', weight_shape='c hI wI hO wO'

We split the tensor into patches of shape hI x wI and mixed within each patch, independently for every channel.
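A minimal sketch with hypothetical sizes (64 channels, 4x4 patches):

# each 4x4 patch is mixed on its own, with a separate weight per channel
local_mix = Mix('b c (h hI) (w wI) -> b c (h hO) (w wO)', weight_shape='c hI wI hO wO',
                c=64, hI=4, wI=4, hO=4, wO=4)
x = torch.zeros([2, 64, 32, 32])
assert local_mix(x).shape == (2, 64, 32, 32)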
Mixing in subgrids¶
Ok, done with local mixing. How to collect information from the whole image?
Well, you can again densely connect all the tokens, but all-to-all connection is too expensive.
TODO: need some image here to show sub-grids and information exchange.
Here is the EinMix way: split the image into subgrids (each subgrid has steps h and w), and densely connect the tokens within each subgrid:
'b c (hI h) (wI w) -> b c (hO h) (wO w)', weight_shape='c hI wI hO wO'
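Again a minimal sketch with hypothetical sizes; only the position of the step axes in the pattern differs from the previous example:

# tokens that share the same offset within the (h, w) step form a subgrid and are mixed together
subgrid_mix = Mix('b c (hI h) (wI w) -> b c (hO h) (wO w)', weight_shape='c hI wI hO wO',
                  c=64, hI=4, wI=4, hO=4, wO=4)
x = torch.zeros([2, 64, 32, 32])
assert subgrid_mix(x).shape == (2, 64, 32, 32)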
Going deeper¶
And that's just the very tip of the iceberg.
- Want to mix part of an axis? — No problem!
- ... in a grid-like manner — Supported!
- ... while mixing channels within group? — Welcome!
- In 2d/3d/4d? — Sure!
- I don't use pytorch — EinMix is available for multiple frameworks!
Hopefully this guide helped you to find MLPs more interesting and intriguing. And simpler to experiment with.