
einops.pack and einops.unpack


Packs several tensors into one. See the einops tutorial for an introduction to packing (and how it replaces stack and concatenation).
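A minimal NumPy sketch of why pack can stand in for both stack and concatenate. It assumes pack works as in the einops source (reshape the axes matched by '*' in each input into a single axis, then concatenate along that axis). The helper `pack_like` and its `n_before`/`n_after` arguments are hypothetical names used only for illustration; they are not part of einops.

```python
import numpy as np


def pack_like(tensors, n_before, n_after):
    """Hypothetical helper mirroring pack's reshape-then-concatenate step."""
    reshaped = []
    for t in tensors:
        end = t.ndim - n_after
        # collapse everything between the leading and trailing named axes into one axis
        reshaped.append(t.reshape(*t.shape[:n_before], -1, *t.shape[end:]))
    return np.concatenate(reshaped, axis=n_before)


rng = np.random.default_rng(0)

# like '* h w': nothing is matched to '*', so each input gets a packed axis of length 1
frames = [rng.standard_normal((4, 5)) for _ in range(3)]
stacked = pack_like(frames, n_before=0, n_after=2)
assert np.array_equal(stacked, np.stack(frames, axis=0))

# like 'b *': the trailing axis is matched to '*', so this is plain concatenation
feats = [rng.standard_normal((8, 3)), rng.standard_normal((8, 5))]
joined = pack_like(feats, n_before=1, n_after=0)
assert np.array_equal(joined, np.concatenate(feats, axis=1))
```

An empty set of '*' axes yields a length-1 packed axis (the product of no dimensions is 1), which is why concatenating such inputs amounts to stacking them.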


Parameters:

Name     Type               Description
tensors  Sequence[~Tensor]  tensors to be packed, can be of different dimensionality
pattern  str                pattern that is shared for all inputs and output, e.g. "i j * k" or "batch seq *"



Returns:

Type                                                     Description
Tuple[~Tensor, List[Union[Tuple[int, ...], List[int]]]]  (packed_tensor, packed_shapes aka PS)


>>> from einops import pack, unpack
>>> from numpy import zeros as Z
>>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
>>> packed, ps = pack(inputs, 'i j * k')
>>> packed.shape, ps
((2, 3, 71, 5), [(), (7,), (7, 9)])

In this example, axes were matched as i=2, j=3, k=5 based on their order (first, second, and last). All remaining axes were 'packed', i.e. flattened into a single axis, and the results were concatenated along it. PS (packed shapes) records, for every input, the axes that were matched to '*'. The resulting tensor has as many elements as all inputs combined.
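The shape bookkeeping in this example can be reproduced with plain Python. This is a sketch that assumes, as in the einops source, that each PS entry is the slice of an input's shape between the axes before '*' and the axes after it; `packed_shapes_for` is a hypothetical helper, not an einops function.

```python
from math import prod


def packed_shapes_for(shapes, n_before, n_after):
    """Hypothetical helper: the axes each input contributes to '*' (its PS entry)."""
    return [tuple(s[n_before:len(s) - n_after]) for s in shapes]


shapes = [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
ps = packed_shapes_for(shapes, n_before=2, n_after=1)  # pattern 'i j * k'
packed_len = sum(prod(p) for p in ps)                  # 1 + 7 + 63
print(ps, packed_len)  # [(), (7,), (7, 9)] 71

# element count is preserved: 2 * 3 * 71 * 5 equals the sum over all inputs
print(2 * 3 * packed_len * 5 == sum(prod(s) for s in shapes))  # True
```

The input with nothing matched to '*' still contributes one position, because the product of an empty shape is 1.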

Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct the original inputs.

>>> inputs_unpacked = unpack(packed, ps, 'i j * k')
>>> [x.shape for x in inputs_unpacked]
[(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
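The reverse direction can also be sketched in NumPy. This assumes unpack slices the packed axis at cumulative offsets computed from PS and reshapes each slice back, as the einops source does. `unpack_like` is a hypothetical helper; unlike einops it handles neither -1 entries in PS nor other backends.

```python
from math import prod

import numpy as np


def unpack_like(packed, ps, n_before):
    """Hypothetical helper: split axis `n_before` of `packed` according to ps."""
    head, tail = packed.shape[:n_before], packed.shape[n_before + 1:]
    outputs, start = [], 0
    for shape in ps:
        stop = start + prod(shape)  # prod(()) == 1: an empty shape occupies one slot
        index = (slice(None),) * n_before + (slice(start, stop),)
        outputs.append(packed[index].reshape(*head, *shape, *tail))
        start = stop
    return outputs


inputs = [np.zeros((2, 3, 5)), np.ones((2, 3, 7, 5)), np.full((2, 3, 7, 9, 5), 2.0)]
ps = [(), (7,), (7, 9)]
# pack for 'i j * k': flatten the '*' axes of each input, then concatenate along axis 2
packed = np.concatenate([x.reshape(2, 3, -1, 5) for x in inputs], axis=2)
restored = unpack_like(packed, ps, n_before=2)
print([r.shape for r in restored])  # [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
```

Filling each input with a distinct value makes it easy to confirm that every restored piece came from the right slice.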

Read the tutorial for an introduction and application scenarios.

Source code in einops/
def pack(tensors: Sequence[Tensor], pattern: str) -> Tuple[Tensor, List[Shape]]:
    """
    Packs several tensors into one.
    See einops tutorial for introduction into packing (and how it replaces stack and concatenation).

    Parameters:
        tensors: tensors to be packed, can be of different dimensionality
        pattern: pattern that is shared for all inputs and output, e.g. "i j * k" or "batch seq *"

    Returns:
        (packed_tensor, packed_shapes aka PS)

    >>> from numpy import zeros as Z
    >>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
    >>> packed, ps = pack(inputs, 'i j * k')
    >>> packed.shape, ps
    ((2, 3, 71, 5), [(), (7,), (7, 9)])

    In this example, axes were matched to: i=2, j=3, k=5 based on order (first, second, and last).
    All other axes were 'packed' and concatenated.
    PS (packed shapes) contains information about axes that were matched to '*' in every input.
    Resulting tensor has as many elements as all inputs in total.

    Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order.

    >>> inputs_unpacked = unpack(packed, ps, 'i j * k')
    >>> [x.shape for x in inputs_unpacked]
    [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]

    Read the tutorial for introduction and application scenarios.
    """
    n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, 'pack')

    # packing zero tensors is illegal
    backend = get_backend(tensors[0])

    reshaped_tensors: List[Tensor] = []
    packed_shapes: List[Shape] = []
    for i, tensor in enumerate(tensors):
        shape = backend.shape(tensor)
        if len(shape) < min_axes:
            raise EinopsError(f'packed tensor #{i} (enumeration starts with 0) has shape {shape}, '
                              f'while pattern {pattern} assumes at least {min_axes} axes')
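        axis_after_packed_axes = len(shape) - n_axes_after
        # remember which axes were matched to '*' so that unpack can restore them
        packed_shapes.append(shape[n_axes_before:axis_after_packed_axes])
        # flatten the '*' axes into a single axis
        reshaped_tensors.append(
            backend.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:]))
        )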

    return backend.concat(reshaped_tensors, axis=n_axes_before), packed_shapes
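The three numbers returned by analyze_pattern drive both pack and unpack. Below is a hypothetical, simplified stand-in called `count_axes`, assuming a pattern is a space-separated list of axis names containing exactly one '*'; the real analyze_pattern in einops may validate more than this. Because '*' can match zero axes, the minimum rank of an input is just the number of named axes.

```python
from typing import Tuple


def count_axes(pattern: str) -> Tuple[int, int, int]:
    """Hypothetical simplified stand-in for analyze_pattern (not the einops code)."""
    axes = pattern.split()
    if axes.count('*') != 1:
        raise ValueError(f'pattern {pattern!r} must contain exactly one *')
    n_before = axes.index('*')
    n_after = len(axes) - n_before - 1
    # '*' may match zero axes, so an input needs at least the named axes
    min_axes = n_before + n_after
    return n_before, n_after, min_axes


print(count_axes('i j * k'))      # (2, 1, 3)
print(count_axes('batch seq *'))  # (2, 0, 2)

# a 2-D input has too few axes for 'i j * k', so pack would reject it
n_before, n_after, min_axes = count_axes('i j * k')
print(len((2, 3)) < min_axes)  # True
```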


Unpacks a single tensor into several by splitting over a selected axis. See the einops tutorial for an introduction to packing (and how it replaces stack and concatenation).


Parameters:

Name           Type                                     Description
tensor         ~Tensor                                  tensor to be unpacked
packed_shapes  List[Union[Tuple[int, ...], List[int]]]  packed_shapes (aka PS), a list of shapes that take the place of '*' in each output; the output contains one tensor for every provided shape
pattern        str                                      pattern that is shared by the input and all outputs, e.g. "i j * k" or "batch seq *", where * designates the axis to be unpacked



Returns:

Type           Description
List[~Tensor]  list of tensors

If the framework supports views, the results are views of the original tensor.
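To illustrate the note on views: in NumPy, basic slicing returns a view, and a reshape that only splits the sliced axis into several axes can also be done without copying. The sketch below assumes unpack uses that slice-then-reshape strategy (as the einops source does) and works on a plain NumPy array rather than calling unpack; other frameworks may behave differently.

```python
import numpy as np

packed = np.zeros((2, 3, 71, 5))
# basic slicing along the packed axis returns a view...
segment = packed[:, :, 8:71]
# ...and splitting that one axis (63 -> 7 x 9) is possible without a copy
piece = segment.reshape(2, 3, 7, 9, 5)
print(np.shares_memory(piece, packed))  # True

piece[0, 0, 0, 0, 0] = 1.0  # write through the unpacked piece
print(packed[0, 0, 8, 0])   # 1.0: the packed array sees the change
```

A consequence worth keeping in mind: modifying an unpacked output in place can modify the packed tensor as well.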


>>> from einops import pack, unpack
>>> from numpy import zeros as Z
>>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
>>> packed, ps = pack(inputs, 'i j * k')
>>> packed.shape, ps
((2, 3, 71, 5), [(), (7,), (7, 9)])

In this example, axes were matched as i=2, j=3, k=5 based on their order (first, second, and last). All remaining axes were 'packed', i.e. flattened into a single axis, and the results were concatenated along it. PS (packed shapes) records, for every input, the axes that were matched to '*'. The resulting tensor has as many elements as all inputs combined.

Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct the original inputs.

>>> inputs_unpacked = unpack(packed, ps, 'i j * k')
>>> [x.shape for x in inputs_unpacked]
[(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]

Read the tutorial for an introduction and application scenarios.

Source code in einops/
def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Tensor]:
    """
    Unpacks a single tensor into several by splitting over a selected axis.
    See einops tutorial for introduction into packing (and how it replaces stack and concatenation).

    Parameters:
        tensor: tensor to be unpacked
        packed_shapes: packed_shapes (aka PS) is a list of shapes that take place of '*' in each output.
            output will contain a single tensor for every provided shape
        pattern: pattern that is shared for input and all outputs, e.g. "i j * k" or "batch seq *",
            where * designates an axis to be unpacked

    Returns:
        list of tensors

    If framework supports views, results are views to the original tensor.

    >>> from numpy import zeros as Z
    >>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
    >>> packed, ps = pack(inputs, 'i j * k')
    >>> packed.shape, ps
    ((2, 3, 71, 5), [(), (7,), (7, 9)])

    In this example, axes were matched to: i=2, j=3, k=5 based on order (first, second, and last).
    All other axes were 'packed' and concatenated.
    PS (packed shapes) contains information about axes that were matched to '*' in every input.
    Resulting tensor has as many elements as all inputs in total.

    Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order.

    >>> inputs_unpacked = unpack(packed, ps, 'i j * k')
    >>> [x.shape for x in inputs_unpacked]
    [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]

    Read the tutorial for introduction and application scenarios.
    """
    n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname='unpack')

    backend = get_backend(tensor)
    input_shape = backend.shape(tensor)
    if len(input_shape) != n_axes_before + 1 + n_axes_after:
        raise EinopsError(f'unpack(..., {pattern}) received input of wrong dim with shape {input_shape}')

    unpacked_axis: int = n_axes_before

    lengths_of_composed_axes: List[int] = [
        -1 if -1 in p_shape else prod(p_shape)
        for p_shape in packed_shapes
    ]

    n_unknown_composed_axes = sum(int(x == -1) for x in lengths_of_composed_axes)
    if n_unknown_composed_axes > 1:
        raise EinopsError(
            f"unpack(..., {pattern}) received more than one -1 in {packed_shapes} and can't infer dimensions"
        )

    # following manipulations allow skipping some shape verifications
    # and leave them to backends

    # [[], [2, 3], [4], [-1, 5], [6]] < example of packed_shapes
    # split positions when computed should be
    # [0,   1,      7,   11,      N-6 , N ], where N = length of axis
    split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]]
    if n_unknown_composed_axes == 0:
        for i, x in enumerate(lengths_of_composed_axes[:-1]):
            split_positions[i + 1] = split_positions[i] + x
    else:
        unknown_composed_axis: int = lengths_of_composed_axes.index(-1)
        for i in range(unknown_composed_axis):
            split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i]
        for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]:
            split_positions[j] = split_positions[j + 1] - lengths_of_composed_axes[j]

    shape_start = input_shape[:unpacked_axis]
    shape_end = input_shape[unpacked_axis + 1:]
    slice_filler = (slice(None, None),) * unpacked_axis
    try:
        return [
            backend.reshape(
                # shortest way to slice an arbitrary axis
                tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]))],
                (*shape_start, *element_shape, *shape_end)
            )
            for i, element_shape in enumerate(packed_shapes)
        ]
    except BaseException:
        # this hits if there is an error during reshapes, which means passed shapes were incorrect
        raise RuntimeError(f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}'
                           f' into requested {packed_shapes}')
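The split-position comment in the source can be checked with a pure-Python mirror of that logic. `split_positions` here is a hypothetical standalone function extracted for illustration. Like the original, it allows at most one entry of packed_shapes to contain -1, and it does not verify the final segment's length, leaving that to the subsequent reshape.

```python
from math import prod
from typing import List, Sequence


def split_positions(packed_shapes: Sequence[Sequence[int]], axis_length: int) -> List[int]:
    """Hypothetical mirror of unpack's split-position computation."""
    lengths = [-1 if -1 in s else prod(s) for s in packed_shapes]
    if lengths.count(-1) > 1:
        raise ValueError("can't infer more than one unknown length")
    pos = [0] * len(lengths) + [axis_length]
    if -1 not in lengths:
        # all lengths known: accumulate from the left
        for i, x in enumerate(lengths[:-1]):
            pos[i + 1] = pos[i] + x
    else:
        k = lengths.index(-1)
        for i in range(k):  # fill forward up to the unknown entry
            pos[i + 1] = pos[i] + lengths[i]
        for j in range(len(lengths) - 1, k, -1):  # fill backward from the end
            pos[j] = pos[j + 1] - lengths[j]
    return pos


print(split_positions([[], [2, 3], [4], [-1, 5], [6]], 27))  # [0, 1, 7, 11, 21, 27]
print(split_positions([(), (7,), (7, 9)], 71))               # [0, 1, 8, 71]
```

With an axis of length 27, the unknown entry [-1, 5] receives the slice 11:21, ten elements, which the reshape would then turn into shape (2, 5).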