Parse a tensor shape to dictionary mapping axes names to their lengths.

```python
# Use underscore to skip the dimension in parsing.
>>> x = np.zeros([2, 3, 5, 7])
>>> parse_shape(x, 'batch _ h w')
{'batch': 2, 'h': 5, 'w': 7}

# `parse_shape` output can be used to specify axes_lengths for other operations:
>>> y = np.zeros([700])
>>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
(2, 10, 5, 7)
```

For symbolic frameworks, the returned lengths may be symbols rather than integers.


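Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | | tensor of any supported framework | *required* |
| `pattern` | `str` | str, space separated names for axes, underscore means skip axis | *required* |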



Returns:

| Type | Description |
|------|-------------|
| `dict` | dict, maps axes names to their lengths |

Source code in `einops/einops.py`
def parse_shape(x, pattern: str) -> dict:
    """
    Parse a tensor shape to dictionary mapping axes names to their lengths.

    ```python
    # Use underscore to skip the dimension in parsing.
    >>> x = np.zeros([2, 3, 5, 7])
    >>> parse_shape(x, 'batch _ h w')
    {'batch': 2, 'h': 5, 'w': 7}

    # `parse_shape` output can be used to specify axes_lengths for other operations:
    >>> y = np.zeros([700])
    >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
    (2, 10, 5, 7)
    ```

    For symbolic frameworks, the returned lengths may be symbols rather than integers.

    Parameters:
        x: tensor of any supported framework
        pattern: str, space separated names for axes, underscore means skip axis.
            An ellipsis may stand for zero or more axes; those axes are skipped
            (they do not appear in the result).

    Returns:
        dict, maps axes names to their lengths

    Raises:
        RuntimeError: if the pattern contains composite axes; if the number of
            dimensions of ``x`` does not fit the pattern; if an empty axis group
            corresponds to a dimension whose length is not 1; or if an anonymous
            axis in the pattern does not match the actual length.
    """
    exp = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)
    if exp.has_composed_axes():
        raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}")
    if len(shape) != len(exp.composition):
        if exp.has_ellipsis:
            # The ellipsis itself may cover zero dims, so every other entry of
            # the pattern still needs a matching dimension.
            if len(shape) < len(exp.composition) - 1:
                raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}")
        else:
            raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}")
    if exp.has_ellipsis:
        ellipsis_idx = exp.composition.index(_ellipsis)
        # Replace the ellipsis with one skipped ("_") entry per dimension it
        # covers, so the expanded composition lines up one-to-one with `shape`.
        composition = (
            exp.composition[:ellipsis_idx]
            + [["_"]] * (len(shape) - len(exp.composition) + 1)
            + exp.composition[ellipsis_idx + 1 :]
        )
    else:
        composition = exp.composition
    result = {}
    for axes, axis_length in zip(composition, shape):  # type: ignore
        # axes either [], or [AnonymousAxis] or ['axis_name']
        if len(axes) == 0:
            # An empty group can only correspond to a unit-length dimension.
            if axis_length != 1:
                raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}")
        else:
            [axis] = axes
            if isinstance(axis, str):
                # "_" marks a dimension the caller asked to skip.
                if axis != "_":
                    result[axis] = axis_length
            else:
                # Anonymous axis: its value must equal the real dimension length.
                if axis.value != axis_length:
                    raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}")
    return result