einops.pack and einops.unpack¶
einops 0.6 introduces two more functions to the family: pack and unpack.
Here is what they do:
- unpack reverses pack
- pack reverses unpack
Enlightened by this exhaustive description, let's move to examples.
# we'll use numpy for demo purposes
# operations work the same way with other frameworks
import numpy as np
Stacking data layers¶
Assume we have an RGB image along with a corresponding depth image that we want to stack:
from einops import pack, unpack
h, w = 100, 200
# image_rgb is 3-dimensional (h, w, 3) and depth is 2-dimensional (h, w)
image_rgb = np.random.random([h, w, 3])
image_depth = np.random.random([h, w])
# but we can stack them
image_rgbd, ps = pack([image_rgb, image_depth], 'h w *')
How to read packing patterns¶
The pattern h w * means that:
- the output is 3-dimensional
- the first two axes (h and w) are shared across all inputs and also shared with the output
- the inputs, however, do not have to be 3-dimensional: they can be 2-dim, 3-dim, 4-dim, etc.

Regardless of the inputs' dimensionality, they all get packed into a 3-dim output, and information about how they were packed is stored in PS.
# as you see, pack properly appended depth as one more layer
# and correctly aligned axes!
# this won't work off the shelf with np.concatenate, torch.cat, or the like
image_rgb.shape, image_depth.shape, image_rgbd.shape
((100, 200, 3), (100, 200), (100, 200, 4))
# now let's see what PS keeps.
# PS means Packed Shapes, not PlayStation or Post Script
ps
[(3,), ()]
which reads: the first tensor had shape h, w, *and 3*, while the second tensor had shape h, w *and nothing more*.
That's just enough to reverse packing:
# remove 1-axis in depth image during unpacking. Results are (h, w, 3) and (h, w)
unpacked_rgb, unpacked_depth = unpack(image_rgbd, ps, 'h w *')
unpacked_rgb.shape, unpacked_depth.shape
((100, 200, 3), (100, 200))
We can also unpack the tensor manually in different ways:
# simple unpack by splitting the axis. Results are (h, w, 3) and (h, w, 1)
rgb, depth = unpack(image_rgbd, [[3], [1]], 'h w *')
# different split, both outputs have shape (h, w, 2)
rg, bd = unpack(image_rgbd, [[2], [2]], 'h w *')
# unpack to 4 tensors of shape (h, w). More like 'unstack over last axis'
[r, g, b, d] = unpack(image_rgbd, [[], [], [], []], 'h w *')
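Inputs of completely different dimensionality can go into a single pack call. Here is a small sketch (not from the original notebook, shapes are made up) showing how the trailing axes get flattened and recorded in PS:
# sketch with made-up shapes: packing inputs of different dimensionality
x2 = np.random.random([h, w])          # 2-dim, contributes one 'layer'
x3 = np.random.random([h, w, 5])       # 3-dim, contributes 5 'layers'
x4 = np.random.random([h, w, 2, 3])    # 4-dim, trailing axes flattened to 2*3=6 'layers'
packed, packed_shapes = pack([x2, x3, x4], 'h w *')
packed.shape, packed_shapes
# ((100, 200, 12), [(), (5,), (2, 3)])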
Short summary so far¶
- einops.pack is a 'more generic concatenation' (that can stack too)
- einops.unpack is a 'more generic split'

And, of course, einops functions are more verbose, and reversing concatenation is now dead simple.
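To make the 'more generic concatenation' point concrete, here is a quick sketch (an illustration with made-up shapes, not part of the original): when all inputs already have a trailing axis, pack along 'h w *' matches np.concatenate over the last axis, and it additionally accepts inputs with no trailing axis at all, which np.concatenate cannot mix in.
# sketch: for same-ndim inputs, pack matches concatenation over the last axis
a = np.random.random([h, w, 3])
b = np.random.random([h, w, 2])
packed_ab, _ = pack([a, b], 'h w *')
assert np.allclose(packed_ab, np.concatenate([a, b], axis=-1))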
Compared to other einops functions, pack and unpack have a compact pattern without an arrow, and the same pattern can be used in both pack and unpack. These patterns are very simple: just a sequence of space-separated axis names. One axis is *, all other axes are valid identifiers.
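The starred axis does not have to be the last one, as the later examples also show. A short sketch (made-up shapes) packing along the leading axis:
# sketch: * may appear in any position of the pattern, e.g. first
stacked, shapes = pack([np.zeros([h, w]), np.zeros([2, h, w])], '* h w')
stacked.shape, shapes
# ((3, 100, 200), [(), (2,)])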
Now let's discuss some practical cases.
Auto-batching¶
ML models by default accept batches: a batch of images, a batch of sentences, a batch of audio clips, etc. During debugging or inference, however, it is common to pass a single image instead (and thus the output should be a single prediction).

In this example we'll write universal_predict, which can handle both cases.
from einops import reduce
def image_classifier(images_bhwc):
# mock for image classifier
predictions = reduce(images_bhwc, 'b h w c -> b c', 'mean', h=100, w=200, c=3)
return predictions
def universal_predict(x):
x_packed, ps = pack([x], '* h w c')
predictions_packed = image_classifier(x_packed)
[predictions] = unpack(predictions_packed, ps, '* cls')
return predictions
# works with a single image
print(universal_predict(np.zeros([h, w, 3])).shape)
# works with a batch of images
batch = 5
print(universal_predict(np.zeros([batch, h, w, 3])).shape)
# or even a batch of videos
n_frames = 7
print(universal_predict(np.zeros([batch, n_frames, h, w, 3])).shape)
(3,)
(5, 3)
(5, 7, 3)
What we can learn from this example:
- pack and unpack play nicely together. That's not a coincidence :)
- patterns in pack and unpack may differ, and that's quite common for applications
- unlike other operations in einops, (un)pack does not provide arbitrary reordering of axes
Class token in ViT¶
Let's assume we have a simple transformer model that works with BTC-shaped tensors.
def transformer_mock(x_btc):
# imagine this is a transformer model, a very efficient one
assert len(x_btc.shape) == 3
return x_btc
Let's implement a vision transformer (ViT) with a class token (i.e. a static token whose corresponding output is used to classify an image).
# below it is assumed that you already
# 1) split batch of images into patches 2) applied linear projection and 3) used positional embedding.
# We'll skip that here. But hey, here is an einops-style way of doing all of that in a single shot!
# from einops.layers.torch import EinMix
# patcher_and_posembedder = EinMix('b (h h2) (w w2) c -> b h w c_out', weight_shape='h2 w2 c c_out',
# bias_shape='h w c_out', h2=..., w2=...)
# patch_tokens_bhwc = patcher_and_posembedder(images_bhwc)
# preparations
batch, height, width, c = 6, 16, 16, 256
patch_tokens = np.random.random([batch, height, width, c])
class_tokens = np.zeros([batch, c])
def vit_einops(class_tokens, patch_tokens):
input_packed, ps = pack([class_tokens, patch_tokens], 'b * c')
output_packed = transformer_mock(input_packed)
return unpack(output_packed, ps, 'b * c_out')
class_token_emb, patch_tokens_emb = vit_einops(class_tokens, patch_tokens)
class_token_emb.shape, patch_tokens_emb.shape
((6, 256), (6, 16, 16, 256))
At this point, let's pause and appreciate the conveniences of this pipeline by contrasting it with more 'standard' code:
def vit_vanilla(class_tokens, patch_tokens):
b, h, w, c = patch_tokens.shape
class_tokens_b1c = class_tokens[:, np.newaxis, :]
patch_tokens_btc = np.reshape(patch_tokens, [b, -1, c])
input_packed = np.concatenate([class_tokens_b1c, patch_tokens_btc], axis=1)
output_packed = transformer_mock(input_packed)
class_token_emb = np.squeeze(output_packed[:, :1, :], 1)
patch_tokens_emb = np.reshape(output_packed[:, 1:, :], [b, h, w, -1])
return class_token_emb, patch_tokens_emb
class_token_emb2, patch_tokens_emb2 = vit_vanilla(class_tokens, patch_tokens)
assert np.allclose(class_token_emb, class_token_emb2)
assert np.allclose(patch_tokens_emb, patch_tokens_emb2)
Notably, all the packing and unpacking, reshapes, and adding/removing of dummy axes fit into a couple of lines.
Packing different modalities together¶
We can extend the previous example: it is quite common to mix inputs of different types in transformers. The simplest approach is to mix tokens from all inputs:
all_inputs = [text_tokens_btc, image_bhwc, task_token_bc, static_tokens_bnc]
inputs_packed, ps = pack(all_inputs, 'b * c')
and you can unpack the resulting tokens back into the same structure.
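A runnable version of the snippet above, with made-up token counts and sizes (the variable shapes here are illustrative assumptions, not from the original):
b, c = 4, 256
text_tokens_btc   = np.random.random([b, 77, c])      # e.g. 77 text tokens
image_bhwc        = np.random.random([b, 16, 16, c])  # a 16x16 grid of patch tokens
task_token_bc     = np.random.random([b, c])          # a single task token
static_tokens_bnc = np.random.random([b, 5, c])       # 5 static tokens
all_inputs = [text_tokens_btc, image_bhwc, task_token_bc, static_tokens_bnc]
inputs_packed, ps = pack(all_inputs, 'b * c')
inputs_packed.shape   # (4, 77 + 256 + 1 + 5, 256)
# after running a transformer on inputs_packed, unpack restores each modality's structure
text_out, image_out, task_out, static_out = unpack(inputs_packed, ps, 'b * c')
image_out.shape       # (4, 16, 16, 256)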
Packing data coming from different sources together¶
The most notable example is, of course, GANs:
input_ims, ps = pack([true_images, fake_images], '* h w c')
true_pred, fake_pred = unpack(model(input_ims), ps, '* c')
true_pred and fake_pred are handled differently, which is why we separated them.
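A minimal runnable sketch of the GAN case (the batch sizes and the mock discriminator below are assumptions for illustration):
true_images = np.random.random([8, h, w, 3])
fake_images = np.random.random([6, h, w, 3])   # batch sizes do not even have to match
mock_discriminator = lambda ims: reduce(ims, 'b h w c -> b c', 'mean')
input_ims, ps = pack([true_images, fake_images], '* h w c')
true_pred, fake_pred = unpack(mock_discriminator(input_ims), ps, '* c')
true_pred.shape, fake_pred.shape   # ((8, 3), (6, 3))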
Predicting multiple outputs at the same time¶
It is quite common to pack predictions of multiple target values into a single output layer. This is more efficient, but the code is less readable. For example, this is how detection code may look:
def loss_detection(model_output_bhwc, mask_h: int, mask_w: int, n_classes: int):
output = model_output_bhwc
confidence = output[..., 0].sigmoid()
bbox_x_shift = output[..., 1].sigmoid()
bbox_y_shift = output[..., 2].sigmoid()
bbox_w = output[..., 3]
bbox_h = output[..., 4]
mask_logits = output[..., 5: 5 + mask_h * mask_w]
mask_logits = mask_logits.reshape([*mask_logits.shape[:-1], mask_h, mask_w])
class_logits = output[..., 5 + mask_h * mask_w:]
assert class_logits.shape[-1] == n_classes, class_logits.shape[-1]
# downstream computations
return confidence, bbox_x_shift, bbox_y_shift, bbox_h, bbox_w, mask_logits, class_logits
When the same logic is implemented in einops, there is no need to memorize offsets.
Additionally, reshapes and shape checks are automatic:
def loss_detection_einops(model_output, mask_h: int, mask_w: int, n_classes: int):
confidence, bbox_x_shift, bbox_y_shift, bbox_w, bbox_h, mask_logits, class_logits \
= unpack(model_output, [[]] * 5 + [[mask_h, mask_w], [n_classes]], 'b h w *')
confidence = confidence.sigmoid()
bbox_x_shift = bbox_x_shift.sigmoid()
bbox_y_shift = bbox_y_shift.sigmoid()
# downstream computations
return confidence, bbox_x_shift, bbox_y_shift, bbox_h, bbox_w, mask_logits, class_logits
# check that results are identical
import torch
dims = dict(mask_h=6, mask_w=8, n_classes=19)
model_output = torch.randn([3, 5, 7, 5 + dims['mask_h'] * dims['mask_w'] + dims['n_classes']])
for a, b in zip(loss_detection(model_output, **dims), loss_detection_einops(model_output, **dims)):
assert torch.allclose(a, b)
Or maybe reinforcement learning is closer to your heart? If so, predicting multiple outputs is valuable there too:
action_logits, reward_expectation, q_values, expected_entropy_after_action = \
unpack(predictions_btc, [[n_actions], [], [n_actions], [n_actions]], 'b step *')
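A runnable sketch with made-up sizes (n_actions, the batch size, and the step count are assumptions):
n_actions = 4
predictions_btc = torch.randn([2, 10, n_actions + 1 + 2 * n_actions])
action_logits, reward_expectation, q_values, expected_entropy_after_action = \
    unpack(predictions_btc, [[n_actions], [], [n_actions], [n_actions]], 'b step *')
action_logits.shape, reward_expectation.shape   # (torch.Size([2, 10, 4]), torch.Size([2, 10]))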
That's all for today!¶
Happy packing and unpacking!