einops.einsum

einops.einsum performs einsum operations using einops-style named axes, computing tensor products over an arbitrary number of tensors. Unlike the usual einsum call signature, here you must pass the tensors first and the pattern last.

Note that rearrange-style axis compositions such as "(batch chan) out" and singleton axes () are not currently supported in the pattern.

Examples:

For a given pattern such as:

>>> x, y, z = np.random.randn(3, 20, 20, 20)
>>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k")

the following formula is computed:

output[a, b, k] =
    \sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k]

where the summation runs over c, d, and g because those axis names do not appear on the right-hand side.
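The formula can be checked numerically with explicit loops. The sketch below uses only numpy and writes the pattern with single-letter axes for `np.einsum`, on the assumption that this is equivalent to the named pattern above; the small axis size of 4 is chosen only to keep the loops cheap.

```python
import numpy as np

rng = np.random.default_rng(0)
x, y, z = rng.standard_normal((3, 4, 4, 4))  # three tensors of shape (4, 4, 4)

# Explicit version of the formula: sum over c, d, g for each (a, b, k).
expected = np.zeros((4, 4, 4))
for a in range(4):
    for b in range(4):
        for k in range(4):
            total = 0.0
            for c in range(4):
                for d in range(4):
                    for g in range(4):
                        total += x[a, b, c] * y[c, b, d] * z[a, g, k]
            expected[a, b, k] = total

# The same contraction written as a single-letter einsum pattern.
output = np.einsum("abc,cbd,agk->abk", x, y, z)
print(np.allclose(output, expected))
```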

Let's see some additional examples:

# Filter a set of images:
>>> batched_images = np.random.randn(128, 16, 16)
>>> filters = np.random.randn(16, 16, 30)
>>> result = einsum(batched_images, filters,
...                 "batch h w, h w channel -> batch channel")
>>> result.shape
(128, 30)

# Matrix multiplication, with an unknown input shape:
>>> batch_shape = (50, 30)
>>> data = np.random.randn(*batch_shape, 20)
>>> weights = np.random.randn(10, 20)
>>> result = einsum(weights, data,
...                 "out_dim in_dim, ... in_dim -> ... out_dim")
>>> result.shape
(50, 30, 10)

# Matrix trace on a single tensor:
>>> matrix = np.random.randn(10, 10)
>>> result = einsum(matrix, "i i ->")
>>> result.shape
()

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensors_and_pattern | Union[Tensor, str] | tensors: tensors of any supported library (numpy, tensorflow, pytorch, jax). pattern: string, einsum pattern, with commas separating specifications for each tensor. pattern should be provided after all tensors. | () |

Returns:

| Type | Description |
|---|---|
| Tensor | Tensor of the same type as input, after processing with einsum. |

Source code in einops/einops.py
def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor:
    r"""
    einops.einsum calls einsum operations with einops-style named
    axes indexing, computing tensor products with an arbitrary
    number of tensors. Unlike typical einsum syntax, here you must
    pass tensors first, and then the pattern.

    Also, note that rearrange operations such as `"(batch chan) out"`,
    or singleton axes `()`, are not currently supported.

    Examples:

    For a given pattern such as:
    ```python
    >>> x, y, z = np.random.randn(3, 20, 20, 20)
    >>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k")

    ```
    the following formula is computed:
    ```tex
    output[a, b, k] =
        \sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k]
    ```
    where the summation over `c`, `d`, and `g` is performed
    because those axes names do not appear on the right-hand side.

    Let's see some additional examples:
    ```python
    # Filter a set of images:
    >>> batched_images = np.random.randn(128, 16, 16)
    >>> filters = np.random.randn(16, 16, 30)
    >>> result = einsum(batched_images, filters,
    ...                 "batch h w, h w channel -> batch channel")
    >>> result.shape
    (128, 30)

    # Matrix multiplication, with an unknown input shape:
    >>> batch_shape = (50, 30)
    >>> data = np.random.randn(*batch_shape, 20)
    >>> weights = np.random.randn(10, 20)
    >>> result = einsum(weights, data,
    ...                 "out_dim in_dim, ... in_dim -> ... out_dim")
    >>> result.shape
    (50, 30, 10)

    # Matrix trace on a single tensor:
    >>> matrix = np.random.randn(10, 10)
    >>> result = einsum(matrix, "i i ->")
    >>> result.shape
    ()

    ```

    Parameters:
        tensors_and_pattern:
            tensors: tensors of any supported library (numpy, tensorflow, pytorch, jax).
            pattern: string, einsum pattern, with commas
                separating specifications for each tensor.
                pattern should be provided after all tensors.

    Returns:
        Tensor of the same type as input, after processing with einsum.

    """
    if len(tensors_and_pattern) <= 1:
        raise ValueError(
            "`einops.einsum` takes at minimum two arguments: the tensors (at least one), followed by the pattern."
        )
    pattern = tensors_and_pattern[-1]
    if not isinstance(pattern, str):
        raise ValueError(
            "The last argument passed to `einops.einsum` must be a string, representing the einsum pattern."
        )
    tensors = tensors_and_pattern[:-1]
    pattern = _compactify_pattern_for_einsum(pattern)
    return get_backend(tensors[0]).einsum(pattern, *tensors)
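The listing shows the named pattern being converted before it is handed to the backend's own `einsum`. The body of `_compactify_pattern_for_einsum` is not shown here, so the following is only a rough sketch of what such a conversion might look like: `compactify_sketch` is a hypothetical, simplified stand-in that maps each axis name to a single letter and skips the validation a real implementation would need.

```python
import string

def compactify_sketch(pattern: str) -> str:
    """Hypothetical stand-in: map named axes to single letters.

    Not the einops implementation. It does not reject parentheses,
    unknown output axes, or patterns with more than 52 distinct names.
    """
    if "->" not in pattern:
        raise ValueError("pattern must contain '->'")
    lefts_str, right_str = pattern.split("->")
    letters = iter(string.ascii_letters)
    mapping = {}  # axis name -> single letter

    def compact(spec: str) -> str:
        out = ""
        for name in spec.split():
            if name == "...":
                out += "..."  # ellipsis is passed through unchanged
            else:
                if name not in mapping:
                    mapping[name] = next(letters)
                out += mapping[name]
        return out

    lefts = ",".join(compact(spec) for spec in lefts_str.split(","))
    return lefts + "->" + compact(right_str)

print(compactify_sketch("batch h w, h w channel -> batch channel"))
print(compactify_sketch("out_dim in_dim, ... in_dim -> ... out_dim"))
print(compactify_sketch("i i ->"))
```

The resulting strings are in the compact form that backends such as numpy accept, which is consistent with the final line of the listing passing the converted pattern followed by the tensors.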