Skip to content

einops.parse_shape

Parse a tensor shape into a dictionary mapping axis names to their lengths.

# Use underscore to skip the dimension in parsing.
>>> x = np.zeros([2, 3, 5, 7])
>>> parse_shape(x, 'batch _ h w')
{'batch': 2, 'h': 5, 'w': 7}

# `parse_shape` output can be used to specify axes_lengths for other operations:
>>> y = np.zeros([700])
>>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
(2, 10, 5, 7)

For symbolic frameworks, the returned lengths may be symbols rather than integers.

Parameters:

Name Type Description Default
x Tensor

tensor of any supported framework

required
pattern str

Space-separated axis names; an underscore means the corresponding axis is skipped.

required

Returns:

Type Description
dict

dict, maps axes names to their lengths

Source code in einops/einops.py
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
def parse_shape(x: Tensor, pattern: str) -> dict:
    """
    Read the shape of a tensor into a dictionary keyed by axis name.

    ```python
    # An underscore in the pattern skips that dimension.
    >>> x = np.zeros([2, 3, 5, 7])
    >>> parse_shape(x, 'batch _ h w')
    {'batch': 2, 'h': 5, 'w': 7}

    # The result can be forwarded as axes_lengths to other operations:
    >>> y = np.zeros([700])
    >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
    (2, 10, 5, 7)

    ```

    The lengths are whatever the backend reports for the shape, so for
    symbolic frameworks they may be symbols rather than integers.

    If the pattern contains an ellipsis, it stands for every dimension not
    matched by the other entries (possibly none); those dimensions are skipped.

    Parameters:
        x: tensor of any supported framework
        pattern: space-separated axis names; an underscore skips the axis

    Returns:
        dict mapping each named axis to its length

    Raises:
        RuntimeError: if the pattern contains composed axes, if the number of
            dimensions does not fit the pattern, if an empty axis group is
            matched against a length other than 1, or if an anonymous axis
            does not match the actual length.
    """
    parsed = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)
    if parsed.has_composed_axes():
        raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}")

    n_dims = len(shape)
    n_groups = len(parsed.composition)

    if parsed.has_ellipsis:
        # The ellipsis itself occupies one slot in the composition and may
        # expand to zero dimensions, hence the "- 1".
        if n_dims < n_groups - 1:
            raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}")
        ellipsis_pos = parsed.composition.index(_ellipsis)
        before = parsed.composition[:ellipsis_pos]
        after = parsed.composition[ellipsis_pos + 1 :]
        # Substitute the ellipsis with as many skip markers as it has to cover.
        skipped = ["_"] * (n_dims - n_groups + 1)
        composition = before + skipped + after
    elif n_dims != n_groups:
        raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}")
    else:
        composition = parsed.composition

    lengths = {}
    for group, dim in zip(composition, shape):  # type: ignore
        # Each group is empty, holds one anonymous axis, or holds one axis name.
        if not group:
            if dim != 1:
                raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}")
            continue

        (entry,) = group
        if not isinstance(entry, str):
            # Anonymous axis: carries its own expected length.
            if entry.value != dim:
                raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}")
            continue

        if entry != "_":
            lengths[entry] = dim
    return lengths