Skip to content

einops.parse_shape

Parse a tensor shape to dictionary mapping axes names to their lengths.

# Use underscore to skip the dimension in parsing.
>>> x = np.zeros([2, 3, 5, 7])
>>> parse_shape(x, 'batch _ h w')
{'batch': 2, 'h': 5, 'w': 7}

# `parse_shape` output can be used to specify axes_lengths for other operations:
>>> y = np.zeros([700])
>>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
(2, 10, 5, 7)

For symbolic frameworks, the returned lengths may be symbols rather than integers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `x` | `Tensor` | tensor of any supported framework | required |
| `pattern` | `str` | space-separated names for axes; an underscore means "skip this axis" | required |

Returns:

| Type | Description |
| --- | --- |
| `dict` | maps axes names to their lengths |

Source code in `einops/einops.py` (lines 677–736)
def parse_shape(x: Tensor, pattern: str) -> dict:
    """
    Parse a tensor shape to dictionary mapping axes names to their lengths.

    ```python
    # Use underscore to skip the dimension in parsing.
    >>> x = np.zeros([2, 3, 5, 7])
    >>> parse_shape(x, 'batch _ h w')
    {'batch': 2, 'h': 5, 'w': 7}

    # `parse_shape` output can be used to specify axes_lengths for other operations:
    >>> y = np.zeros([700])
    >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
    (2, 10, 5, 7)

    # An ellipsis skips any number of consecutive dimensions (including none).
    >>> parse_shape(x, 'batch ... w')
    {'batch': 2, 'w': 7}

    ```

    For symbolic frameworks, the returned lengths may be symbols rather than integers.

    Parameters:
        x: tensor of any supported framework
        pattern: str, space separated names for axes, underscore means skip axis.
            An ellipsis skips all dimensions not matched by the other entries.
            Anonymous axes are validated against the actual length but are not
            included in the result.

    Returns:
        dict, maps axes names to their lengths

    Raises:
        RuntimeError: if the pattern contains composite (parenthesized) axes, if the
            number of dimensions does not fit the pattern, if an empty axis entry does
            not have length 1, or if an anonymous axis does not match the actual length.
    """
    exp = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)
    if exp.has_composed_axes():
        raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}")

    if exp.has_ellipsis:
        # Every entry except the ellipsis consumes exactly one dimension; the ellipsis
        # absorbs whatever is left (possibly nothing), so only a lower bound applies.
        n_explicit = len(exp.composition) - 1
        if len(shape) < n_explicit:
            raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}")
        ellipsis_idx = exp.composition.index(_ellipsis)
        # Expand the ellipsis into one skipped axis per absorbed dimension. The padding
        # uses the same one-element-list form as every other composition entry, so the
        # loop below never has to rely on a bare "_" string unpacking like a list.
        composition = (
            exp.composition[:ellipsis_idx]
            + [["_"]] * (len(shape) - n_explicit)
            + exp.composition[ellipsis_idx + 1 :]
        )
    else:
        if len(shape) != len(exp.composition):
            raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}")
        composition = exp.composition

    result = {}
    for axes, axis_length in zip(composition, shape):  # type: ignore
        # axes either [], or [AnonymousAxis] or ['axis_name']
        if not axes:
            # an empty entry must correspond to a dimension of length 1
            if axis_length != 1:
                raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}")
            continue
        [axis] = axes
        if isinstance(axis, str):
            # "_" marks a skipped dimension (explicit or expanded from the ellipsis)
            if axis != "_":
                result[axis] = axis_length
        elif axis.value != axis_length:
            raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}")
    return result