Skip to content

einops.parse_shape

Parse a tensor shape to dictionary mapping axes names to their lengths.

```python
# Use underscore to skip the dimension in parsing.
>>> x = np.zeros([2, 3, 5, 7])
>>> parse_shape(x, 'batch _ h w')
{'batch': 2, 'h': 5, 'w': 7}

# `parse_shape` output can be used to specify axes_lengths for other operations:
>>> y = np.zeros([700])
>>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
(2, 10, 5, 7)
```

For symbolic frameworks, the returned lengths may be symbols rather than integers.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | | tensor of any of the supported frameworks | required |
| `pattern` | `str` | space-separated names for axes; an underscore means skip that axis | required |

Returns:

| Type | Description |
|------|-------------|
| `dict` | maps axes names to their lengths |

Source code in einops/einops.py
def parse_shape(x, pattern: str):
    """
    Parse a tensor shape into a dictionary mapping axis names to their lengths.

    ```python
    # Use underscore to skip the dimension in parsing.
    >>> x = np.zeros([2, 3, 5, 7])
    >>> parse_shape(x, 'batch _ h w')
    {'batch': 2, 'h': 5, 'w': 7}

    # `parse_shape` output can be used to specify axes_lengths for other operations:
    >>> y = np.zeros([700])
    >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
    (2, 10, 5, 7)

    ```

    Names in `pattern` are matched positionally against the dimensions of `x`.
    An underscore matches a dimension without recording it. An ellipsis, if
    present, absorbs zero or more dimensions, none of which are recorded.
    For symbolic frameworks, the returned lengths may be symbols rather than integers.

    Parameters:
        x: tensor of any of the supported frameworks
        pattern: str, space-separated names for axes; underscore means skip axis

    Returns:
        dict, maps axes names to their lengths

    Raises:
        RuntimeError: if `pattern` contains composite axes, or if its axes
            cannot be matched against the number of dimensions of `x`.
    """
    parsed = ParsedExpression(pattern, allow_underscore=True)
    dims = get_backend(x).shape(x)

    # Only plain (single-name) axes can be read off a shape.
    if parsed.has_composed_axes():
        raise RuntimeError("Can't parse shape with composite axes: {pattern} {shape}".format(
            pattern=pattern, shape=dims))

    layout = parsed.composition
    if parsed.has_ellipsis:
        # The ellipsis may cover zero dimensions, so the shape only has to
        # provide one dimension per non-ellipsis entry of the pattern.
        if len(dims) < len(layout) - 1:
            raise RuntimeError("Can't parse shape with this number of dimensions: {pattern} {shape}".format(
                pattern=pattern, shape=dims))
        split_at = layout.index(_ellipsis)
        n_covered = len(dims) - len(layout) + 1
        # Replace the ellipsis with one '_' placeholder per covered dimension;
        # a one-character string unpacks like a one-element axis list below.
        layout = layout[:split_at] + ['_'] * n_covered + layout[split_at + 1:]
    elif len(dims) != len(layout):
        raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format(
            pattern=pattern, shape=dims))

    # Every remaining entry holds exactly one name; skipped axes are named '_'.
    return {
        axis_name: axis_length
        for (axis_name,), axis_length in zip(layout, dims)
        if axis_name != '_'
    }