
einops.pack and einops.unpack

einops.pack

Packs several tensors into one. See the einops tutorial for an introduction to packing (and how it replaces stack and concatenation).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensors | Sequence[Tensor] | tensors to be packed, can be of different dimensionality | required |
| pattern | str | pattern that is shared for all inputs and output, e.g. "i j * k" or "batch seq *" | required |

Returns:

| Type | Description |
|---|---|
| Tuple[Tensor, List[Shape]] | (packed_tensor, packed_shapes aka PS) |

Example:

>>> from numpy import zeros as Z
>>> from einops import pack, unpack
>>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
>>> packed, ps = pack(inputs, 'i j * k')
>>> packed.shape, ps
((2, 3, 71, 5), [(), (7,), (7, 9)])

In this example, axes were matched by position: i=2 and j=3 (first and second axes) and k=5 (last axis). All remaining axes of each input were flattened into one axis and concatenated along '*', giving a length of 1 + 7 + 7·9 = 71. PS (packed shapes) records, for every input, the axes that were matched to '*'. The resulting tensor has as many elements as all inputs combined.
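The matching described above can be reproduced with plain NumPy, which may help when checking where the 71 comes from. This is only a sketch: `pack_sketch` and its `n_before` / `n_after` arguments are hypothetical names introduced for illustration, not part of einops, and the pattern is assumed to have already been parsed into the number of named axes before and after '*'.

```python
import numpy as np

def pack_sketch(tensors, n_before, n_after):
    """Hypothetical stand-in for pack: n_before/n_after are the numbers of
    named axes before and after '*' in the pattern."""
    reshaped, packed_shapes = [], []
    for t in tensors:
        end = t.ndim - n_after
        # axes matched to '*' are remembered so the packing can be undone later
        packed_shapes.append(t.shape[n_before:end])
        # those axes are flattened into a single axis of inferred length
        reshaped.append(t.reshape((*t.shape[:n_before], -1, *t.shape[end:])))
    # all inputs are joined along the flattened axis
    return np.concatenate(reshaped, axis=n_before), packed_shapes

inputs = [np.zeros([2, 3, 5]), np.zeros([2, 3, 7, 5]), np.zeros([2, 3, 7, 9, 5])]
packed, ps = pack_sketch(inputs, n_before=2, n_after=1)  # corresponds to 'i j * k'
print(packed.shape, ps)  # (2, 3, 71, 5) [(), (7,), (7, 9)]
```

The three inputs contribute 1, 7 and 63 positions to the packed axis, which is why the total element count is preserved.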

Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order.

>>> inputs_unpacked = unpack(packed, ps, 'i j * k')
>>> [x.shape for x in inputs_unpacked]
[(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]

Read the tutorial for introduction and application scenarios.

Source code in einops/packing.py
def pack(tensors: Sequence[Tensor], pattern: str) -> Tuple[Tensor, List[Shape]]:
    """
    Packs several tensors into one.
    See einops tutorial for introduction into packing (and how it replaces stack and concatenation).

    Parameters:
        tensors: tensors to be packed, can be of different dimensionality
        pattern: pattern that is shared for all inputs and output, e.g. "i j * k" or "batch seq *"

    Returns:
        (packed_tensor, packed_shapes aka PS)

    Example:
    ```python
    >>> from numpy import zeros as Z
    >>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
    >>> packed, ps = pack(inputs, 'i j * k')
    >>> packed.shape, ps
    ((2, 3, 71, 5), [(), (7,), (7, 9)])
    ```

    In this example, axes were matched to: i=2, j=3, k=5 based on order (first, second, and last).
    All other axes were 'packed' and concatenated.
    PS (packed shapes) contains information about axes that were matched to '*' in every input.
    Resulting tensor has as many elements as all inputs in total.

    Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order.

    ```python
    >>> inputs_unpacked = unpack(packed, ps, 'i j * k')
    >>> [x.shape for x in inputs_unpacked]
    [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
    ```

    Read the tutorial for introduction and application scenarios.
    """
    n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, "pack")

    # packing zero tensors is illegal
    backend = get_backend(tensors[0])

    reshaped_tensors: List[Tensor] = []
    packed_shapes: List[Shape] = []
    for i, tensor in enumerate(tensors):
        shape = backend.shape(tensor)
        if len(shape) < min_axes:
            raise EinopsError(
                f"packed tensor #{i} (enumeration starts with 0) has shape {shape}, "
                f"while pattern {pattern} assumes at least {min_axes} axes"
            )
        axis_after_packed_axes = len(shape) - n_axes_after
        packed_shapes.append(shape[n_axes_before:axis_after_packed_axes])
        reshaped_tensors.append(backend.reshape(tensor, (*shape[:n_axes_before], -1, *shape[axis_after_packed_axes:])))

    return backend.concat(reshaped_tensors, axis=n_axes_before), packed_shapes


einops.unpack

Unpacks a single tensor into several by splitting over a selected axis. See the einops tutorial for an introduction to packing (and how it replaces stack and concatenation).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensor | Tensor | tensor to be unpacked | required |
| packed_shapes | List[Shape] | packed_shapes (aka PS) is a list of shapes that take the place of '*' in each output. The output will contain a single tensor for every provided shape | required |
| pattern | str | pattern that is shared for input and all outputs, e.g. "i j * k" or "batch seq *", where * designates the axis to be unpacked | required |

Returns:

| Type | Description |
|---|---|
| List[Tensor] | list of tensors |

If the framework supports views, the results are views of the original tensor.

Example:

>>> from numpy import zeros as Z
>>> from einops import pack, unpack
>>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
>>> packed, ps = pack(inputs, 'i j * k')
>>> packed.shape, ps
((2, 3, 71, 5), [(), (7,), (7, 9)])

In this example, axes were matched by position: i=2 and j=3 (first and second axes) and k=5 (last axis). All remaining axes of each input were flattened into one axis and concatenated along '*', giving a length of 1 + 7 + 7·9 = 71. PS (packed shapes) records, for every input, the axes that were matched to '*'. The resulting tensor has as many elements as all inputs combined.

Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order.

>>> inputs_unpacked = unpack(packed, ps, 'i j * k')
>>> [x.shape for x in inputs_unpacked]
[(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]

Read the tutorial for introduction and application scenarios.
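The source listing for unpack indicates that at most one entry of packed_shapes may contain -1, in which case its length along the unpacked axis is inferred from the others. The pure-Python sketch below mirrors that split-position computation for illustration; `split_positions_sketch` is a hypothetical name, not an einops function, and it raises `ValueError` where the library raises its own error type.

```python
from math import prod

def split_positions_sketch(axis_length, packed_shapes):
    """Hypothetical helper: boundaries of each piece along the unpacked axis."""
    lengths = [-1 if -1 in s else prod(s) for s in packed_shapes]
    if lengths.count(-1) > 1:
        raise ValueError("at most one packed shape may contain -1")
    positions = [0] * len(lengths) + [axis_length]
    if -1 not in lengths:
        # all lengths known: accumulate from the left
        for i, n in enumerate(lengths[:-1]):
            positions[i + 1] = positions[i] + n
    else:
        unknown = lengths.index(-1)
        # accumulate from the left up to the unknown piece
        for i in range(unknown):
            positions[i + 1] = positions[i] + lengths[i]
        # subtract from the right down to the unknown piece
        for j in reversed(range(unknown + 1, len(lengths))):
            positions[j] = positions[j + 1] - lengths[j]
    return positions

print(split_positions_sketch(71, [(), (7,), (7, 9)]))   # [0, 1, 8, 71]
print(split_positions_sketch(71, [(), (-1,), (7, 9)]))  # [0, 1, 8, 71]
```

Piece number i then covers the half-open range from `positions[i]` to `positions[i + 1]` along the unpacked axis, so the piece containing -1 receives whatever the known pieces leave over.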

Source code in einops/packing.py
def unpack(tensor: Tensor, packed_shapes: List[Shape], pattern: str) -> List[Tensor]:
    """
    Unpacks a single tensor into several by splitting over a selected axes.
    See einops tutorial for introduction into packing (and how it replaces stack and concatenation).

    Parameters:
        tensor: tensor to be unpacked
        packed_shapes: packed_shapes (aka PS) is a list of shapes that take place of '*' in each output.
            output will contain a single tensor for every provided shape
        pattern: pattern that is shared for input and all outputs, e.g. "i j * k" or "batch seq *",
            where * designates an axis to be unpacked

    Returns:
        list of tensors

    If framework supports views, results are views to the original tensor.

    Example:
    ```python
    >>> from numpy import zeros as Z
    >>> inputs = [Z([2, 3, 5]), Z([2, 3, 7, 5]), Z([2, 3, 7, 9, 5])]
    >>> packed, ps = pack(inputs, 'i j * k')
    >>> packed.shape, ps
    ((2, 3, 71, 5), [(), (7,), (7, 9)])
    ```

    In this example, axes were matched to: i=2, j=3, k=5 based on order (first, second, and last).
    All other axes were 'packed' and concatenated.
    PS (packed shapes) contains information about axes that were matched to '*' in every input.
    Resulting tensor has as many elements as all inputs in total.

    Packing can be reversed with unpack, which additionally needs PS (packed shapes) to reconstruct order.

    ```python
    >>> inputs_unpacked = unpack(packed, ps, 'i j * k')
    >>> [x.shape for x in inputs_unpacked]
    [(2, 3, 5), (2, 3, 7, 5), (2, 3, 7, 9, 5)]
    ```

    Read the tutorial for introduction and application scenarios.
    """
    n_axes_before, n_axes_after, min_axes = analyze_pattern(pattern, opname="unpack")

    backend = get_backend(tensor)
    input_shape = backend.shape(tensor)
    if len(input_shape) != n_axes_before + 1 + n_axes_after:
        raise EinopsError(f"unpack(..., {pattern}) received input of wrong dim with shape {input_shape}")

    unpacked_axis: int = n_axes_before

    lengths_of_composed_axes: List[int] = [-1 if -1 in p_shape else prod(p_shape) for p_shape in packed_shapes]

    n_unknown_composed_axes = sum(int(x == -1) for x in lengths_of_composed_axes)
    if n_unknown_composed_axes > 1:
        raise EinopsError(
            f"unpack(..., {pattern}) received more than one -1 in {packed_shapes} and can't infer dimensions"
        )

    # following manipulations allow to skip some shape verifications
    # and leave it to backends

    # [[], [2, 3], [4], [-1, 5], [6]] < examples of packed_axis
    # split positions when computed should be
    # [0,   1,      7,   11,      N-6 , N ], where N = length of axis
    split_positions = [0] * len(packed_shapes) + [input_shape[unpacked_axis]]
    if n_unknown_composed_axes == 0:
        for i, x in enumerate(lengths_of_composed_axes[:-1]):
            split_positions[i + 1] = split_positions[i] + x
    else:
        unknown_composed_axis: int = lengths_of_composed_axes.index(-1)
        for i in range(unknown_composed_axis):
            split_positions[i + 1] = split_positions[i] + lengths_of_composed_axes[i]
        for j in range(unknown_composed_axis + 1, len(lengths_of_composed_axes))[::-1]:
            split_positions[j] = split_positions[j + 1] - lengths_of_composed_axes[j]

    shape_start = input_shape[:unpacked_axis]
    shape_end = input_shape[unpacked_axis + 1 :]
    slice_filler = (slice(None, None),) * unpacked_axis
    try:
        return [
            backend.reshape(
                # shortest way slice arbitrary axis
                tensor[(*slice_filler, slice(split_positions[i], split_positions[i + 1]))],
                (*shape_start, *element_shape, *shape_end),
            )
            for i, element_shape in enumerate(packed_shapes)
        ]
    except Exception:
        # this hits if there is an error during reshapes, which means passed shapes were incorrect
        raise RuntimeError(
            f'Error during unpack(..., "{pattern}"): could not split axis of size {split_positions[-1]}'
            f" into requested {packed_shapes}"
        )