EinMix: universal toolkit for advanced MLP architectures¶
Recent progress in MLP-based architectures has demonstrated that very specific MLPs can compete with convnets and transformers (and even outperform them).
EinMix allows writing such architectures in a more uniform and readable way.
EinMix — building block of MLPs¶
from einops.layers.torch import EinMix as Mix

# tutorial uses torch. EinMix is available for other frameworks too
import torch
from torch import nn
from torch.nn import functional as F
The logic of EinMix is very close to that of einsum. If you're not familiar with einsum, follow these guides first:

- https://rockt.github.io/2018/04/30/einsum
- https://towardsdatascience.com/einsum-an-underestimated-function-99ca96e2942e

Einsum uniformly describes a number of operations. `EinMix` is a layer (not a function) that implements similar logic, with some differences from einsum.
Let's implement a simple linear layer using einsum:
weight = <...create and initialize parameter...>
bias = <...create and initialize parameter...>
result = torch.einsum('tbc,cd->tbd', embeddings, weight) + bias
The `EinMix` counterpart is:
mix_channels = Mix('t b c -> t b c_out', weight_shape='c c_out', bias_shape='c_out', ...)
result = mix_channels(embeddings)
The main differences compared to plain einsum are:

- the layer takes care of parameter initialization & management
- the weight does not appear in the pattern (it is described separately via `weight_shape`)
- `EinMix` includes a bias term
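To make the correspondence concrete, here is a minimal runnable sketch (the sizes below are made up for illustration) that builds both versions and checks that the results agree in shape:

t, b, c, c_out = 10, 32, 64, 128
embeddings = torch.randn(t, b, c)

# einsum version: we create and manage the parameters ourselves
weight = torch.randn(c, c_out) * c ** -0.5
bias = torch.zeros(c_out)
result_einsum = torch.einsum('tbc,cd->tbd', embeddings, weight) + bias

# EinMix version: the layer creates and initializes its own weight and bias
mix_channels = Mix('t b c -> t b c_out', weight_shape='c c_out', bias_shape='c_out',
                   c=c, c_out=c_out)
result_mix = mix_channels(embeddings)

assert result_einsum.shape == result_mix.shape == (t, b, c_out)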
We'll discuss other differences a bit later; for now, let's implement some elements from MLPMixer.
TokenMixer from MLPMixer — original code¶
We start from the PyTorch implementation of MLPMixer by Jake Tae.

We'll focus on two components of MLPMixer that don't exist in convnets. The first component is TokenMixer:
class MLP(nn.Module):
    def __init__(self, num_features, expansion_factor, dropout):
        super().__init__()
        num_hidden = num_features * expansion_factor
        self.fc1 = nn.Linear(num_features, num_hidden)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(num_hidden, num_features)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        x = self.dropout1(F.gelu(self.fc1(x)))
        x = self.dropout2(self.fc2(x))
        return x


class TokenMixer(nn.Module):
    def __init__(self, num_features, num_patches, expansion_factor, dropout):
        super().__init__()
        self.norm = nn.LayerNorm(num_features)
        self.mlp = MLP(num_patches, expansion_factor, dropout)

    def forward(self, x):
        # x.shape == (batch_size, num_patches, num_features)
        residual = x
        x = self.norm(x)
        x = x.transpose(1, 2)
        # x.shape == (batch_size, num_features, num_patches)
        x = self.mlp(x)
        x = x.transpose(1, 2)
        # x.shape == (batch_size, num_patches, num_features)
        out = x + residual
        return out
TokenMixer from MLPMixer — reimplemented¶
We can significantly reduce the amount of code by using `EinMix`.

- The main caveat addressed by the original code is that `nn.Linear` mixes only the last axis. `EinMix` can mix any axis (or set of axes).
- A sequential structure is preferred, as it is easier to follow.
- There is intentionally no residual connection in `TokenMixer`: honestly, that's not the Mixer's job and should be handled by the caller.
def TokenMixer(num_features: int, n_patches: int, expansion_factor: int, dropout: float):
    n_hidden = n_patches * expansion_factor
    return nn.Sequential(
        nn.LayerNorm(num_features),
        Mix("b hw c -> b hid c", weight_shape="hw hid", bias_shape="hid", hw=n_patches, hid=n_hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        Mix("b hid c -> b hw c", weight_shape="hid hw", bias_shape="hw", hw=n_patches, hid=n_hidden),
        nn.Dropout(dropout),
    )
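A quick sanity check (the sizes here are illustrative): the mixer maps a (batch, patches, features) tensor to a tensor of the same shape.

x = torch.randn(2, 196, 256)   # (batch, num_patches, num_features)
token_mixer = TokenMixer(num_features=256, n_patches=196, expansion_factor=4, dropout=0.1)
assert token_mixer(x).shape == (2, 196, 256)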
You may also check another implementation of MLPMixer from Phil Wang.
Phil solves the issue by repurposing `nn.Conv1d` to mix over the second dimension. Hacky, but it does the job.
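For reference, here is a minimal sketch of that trick (sizes are illustrative): `nn.Conv1d` treats the second dimension as channels, so a kernel-size-1 convolution mixes the patch axis.

x = torch.randn(2, 196, 256)                          # (batch, num_patches, num_features)
mix_patches = nn.Conv1d(196, 196 * 4, kernel_size=1)  # mixes dim 1, i.e. the patch axis
assert mix_patches(x).shape == (2, 196 * 4, 256)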
MLPMixer's patch embeddings (aka ViT patch embeddings) — original¶
The second interesting part of MLPMixer is derived from vision transformers.

At the very beginning, an image is split into patches, and each patch is linearly projected into an embedding:
def check_sizes(image_size, patch_size):
    sqrt_num_patches, remainder = divmod(image_size, patch_size)
    assert remainder == 0, "`image_size` must be divisible by `patch_size`"
    num_patches = sqrt_num_patches ** 2
    return num_patches


class Patcher(nn.Module):
    def __init__(
        self,
        image_size=256,
        patch_size=16,
        in_channels=3,
        num_features=128,
    ):
        _num_patches = check_sizes(image_size, patch_size)
        super().__init__()
        # per-patch fully-connected is equivalent to strided conv2d
        self.patcher = nn.Conv2d(
            in_channels, num_features, kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        patches = self.patcher(x)
        batch_size, num_features, _, _ = patches.shape
        patches = patches.permute(0, 2, 3, 1)
        patches = patches.view(batch_size, -1, num_features)
        return patches
MLPMixer's patch embeddings — reimplemented¶
`EinMix` does all of this in a single operation. This may take some practice to read at first, so let's go step by step:
- `b c_in (h hp) (w wp) ->`: the 4-dimensional input tensor (BCHW-ordered) is split into patches of shape `hp x wp`
- `weight_shape='c_in hp wp c'`: the axes `c_in`, `hp` and `wp` are all absent in the output; the three-dimensional patch tensor was mixed to produce a vector of length `c`
- `-> b (h w) c`: the output is 3-dimensional; all patches were reorganized from an `h x w` grid into a one-dimensional sequence of vectors

We don't need to provide `image_size` beforehand: the new implementation handles images of different sizes, as long as they can be divided into patches.
def patcher(patch_size=16, in_channels=3, num_features=128):
    return Mix("b c_in (h hp) (w wp) -> b (h w) c", weight_shape="c_in hp wp c", bias_shape="c",
               c=num_features, hp=patch_size, wp=patch_size, c_in=in_channels)
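A short usage sketch (batch and image sizes are illustrative): the same layer handles any resolution divisible by the patch size.

images = torch.randn(8, 3, 224, 224)    # (batch, channels, height, width)
patch_embed = patcher(patch_size=16, in_channels=3, num_features=128)
assert patch_embed(images).shape == (8, (224 // 16) ** 2, 128)   # (batch, num_patches, num_features)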
Vision Permutator¶
As a third example, we consider PyTorch-like code from the ViP paper.

Vision Permutator is only slightly more nuanced than the previous models, because

- it operates on the spatial dimensions separately, while MLPMixer and its friends just pack all spatial info into one axis
- it splits channels into groups called 'segments'
class WeightedPermuteMLP(nn.Module):
    def __init__(self, H, W, C, S):
        super().__init__()
        self.proj_h = nn.Linear(H * S, H * S)
        self.proj_w = nn.Linear(W * S, W * S)
        self.proj_c = nn.Linear(C, C)
        self.proj = nn.Linear(C, C)
        self.S = S

    def forward(self, x):
        B, H, W, C = x.shape
        S = self.S
        N = C // S
        x_h = x.reshape(B, H, W, N, S).permute(0, 3, 2, 1, 4).reshape(B, N, W, H*S)
        x_h = self.proj_h(x_h).reshape(B, N, W, H, S).permute(0, 3, 2, 1, 4).reshape(B, H, W, C)

        x_w = x.reshape(B, H, W, N, S).permute(0, 1, 3, 2, 4).reshape(B, H, N, W*S)
        x_w = self.proj_w(x_w).reshape(B, H, N, W, S).permute(0, 1, 3, 2, 4).reshape(B, H, W, C)

        x_c = self.proj_c(x)

        x = x_h + x_w + x_c
        x = self.proj(x)
        return x
That didn't look readable, right?

This code is also quite inflexible: the code in the paper did not support a batch dimension, and multiple changes were necessary to allow batch processing. Such manual rewiring is fragile and can easily result in virtually uncatchable bugs.

Now the good news: each of these long chains of reshapes and permutes can be replaced with a single `EinMix` layer:
class WeightedPermuteMLP_new(nn.Module):
    def __init__(self, H, W, C, seg_len):
        super().__init__()
        assert C % seg_len == 0, f"can't divide {C} into segments of length {seg_len}"
        self.mlp_c = Mix("b h w c -> b h w c0", weight_shape="c c0", bias_shape="c0",
                         c=C, c0=C)
        self.mlp_h = Mix("b h w (n c) -> b h0 w (n c0)", weight_shape="h c h0 c0", bias_shape="h0 c0",
                         h=H, h0=H, c=seg_len, c0=seg_len)
        self.mlp_w = Mix("b h w (n c) -> b h w0 (n c0)", weight_shape="w c w0 c0", bias_shape="w0 c0",
                         w=W, w0=W, c=seg_len, c0=seg_len)
        self.proj = nn.Linear(C, C)

    def forward(self, x):
        x = self.mlp_c(x) + self.mlp_h(x) + self.mlp_w(x)
        return self.proj(x)
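A quick sanity check (sizes are illustrative): like the original block, the reimplementation maps a (B, H, W, C) tensor to the same shape.

x = torch.randn(2, 14, 14, 384)
vip_block = WeightedPermuteMLP_new(H=14, W=14, C=384, seg_len=16)
assert vip_block(x).shape == (2, 14, 14, 384)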
Multi-head attention, once again¶
EinMix can be (mis)used to compute multiple projections and perform transpositions along the way.

For example, `F.scaled_dot_product_attention` wants a specific order of axes and an explicit head axis. We can combine the linear projection with producing the desired axis order in a single operation.
class MultiheadAttention(nn.Module):
    def __init__(self, dim_input, n_heads, head_dim):
        super().__init__()
        self.input_to_qkv = Mix("b t c -> qkv b h t hid", "c qkv h hid",
                                c=dim_input, qkv=3, h=n_heads, hid=head_dim)
        self.out_proj = Mix("b h t hid -> b t c", "h hid c",
                            h=n_heads, hid=head_dim, c=dim_input)

    def forward(self, x):
        q, k, v = self.input_to_qkv(x)  # fused projections, computed in one go
        return self.out_proj(F.scaled_dot_product_attention(q, k, v))  # flash attention
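A quick sanity check (sizes are illustrative): the module maps a (batch, tokens, channels) tensor to the same shape.

x = torch.randn(4, 128, 512)    # (batch, tokens, channels)
mha = MultiheadAttention(dim_input=512, n_heads=8, head_dim=64)
assert mha(x).shape == (4, 128, 512)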
Exercises¶
- Many normalizations (batch norm, layer norm, etc.) use affine scaling afterwards. Implement this scaling using `EinMix`.
- Let's assume you have an input tensor of shape `[b, t, n_groups, n_channels]`, and you want to apply a separate linear layer to every group of channels. This will introduce `n_groups` matrices of shape `[n_channels, n_channels]` and `n_groups` biases of shape `[n_channels]`. Can you perform this operation with just one `EinMix`?
Final remarks¶
`EinMix` helps with MLPs that don't fit into the limited 'mix everything in the last axis' paradigm, and it is especially helpful for non-1d inputs (images, videos, etc.).

However, existing research does not cover the real possibilities of densely connected architectures. Most of its systematic novelty is "mixing along spatial axes actually works".

But `EinMix` provides an astonishing number of other possibilities! Let me mention some examples:
Mixing within a patch on a grid¶
What if you make mixing 'local' in space? Completely doable:
'b c (h hI) (w wI) -> b c (h hO) (w wO)', weight_shape='c hI wI hO wO'
We split the tensor into patches of shape `hI x wI` and mix each patch independently, per channel.
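A minimal sketch of this local mixing (all sizes below are illustrative): every `hI x wI` patch is mapped to an `hO x wO` patch, with a separate weight per channel.

local_mix = Mix('b c (h hI) (w wI) -> b c (h hO) (w wO)', weight_shape='c hI wI hO wO',
                c=32, hI=4, wI=4, hO=4, wO=4)
x = torch.randn(2, 32, 64, 64)
assert local_mix(x).shape == (2, 32, 64, 64)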
Mixing in subgrids¶
The opposite question: how do we collect information from the whole image (without attention)?

Well, you could again densely connect all the tokens, but an all-to-all connection can be too expensive.

Here is the EinMix way: split the image into subgrids (each subgrid is taken with steps `h` and `w`), and densely connect the tokens within each subgrid:
'b c (hI h) (wI w) -> b c (hO h) (wO w)', weight_shape='c hI wI hO wO'
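The sketch is almost the same as the previous one; only the placement of the grid axes inside the parentheses changes (sizes are again illustrative).

subgrid_mix = Mix('b c (hI h) (wI w) -> b c (hO h) (wO w)', weight_shape='c hI wI hO wO',
                  c=32, hI=4, wI=4, hO=4, wO=4)
x = torch.randn(2, 32, 64, 64)
assert subgrid_mix(x).shape == (2, 32, 64, 64)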
Going deeper¶
And that's just the tip of the iceberg.

- Want to mix only a part of an axis? — No problem!
- ... in a grid-like manner? — Supported!
- ... while mixing channels within a group? — Welcome!
- In 2d/3d/4d? — Sure!
- Don't use PyTorch? — EinMix is available for multiple frameworks!
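Here are a couple of sketches for the first and third bullets (axis names and sizes are purely illustrative):

# mix only a part of an axis: channels are split into groups, and every group
# is mixed with the same weight matrix ...
mix_in_groups = Mix('b (g c) t -> b (g c_out) t', weight_shape='c c_out', c=16, c_out=16)
# ... or give every group its own weight matrix
mix_per_group = Mix('b (g c) t -> b (g c_out) t', weight_shape='g c c_out', g=4, c=16, c_out=16)
x = torch.randn(2, 64, 100)
assert mix_in_groups(x).shape == mix_per_group(x).shape == (2, 64, 100)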
Hopefully this guide helped you find MLPs more interesting and intriguing, and simpler to experiment with.