einops.einsum
einops.einsum performs einsum operations with einops-style named-axis indexing, computing tensor products over an arbitrary number of tensors. Unlike the usual einsum syntax, you must pass the tensors first and the pattern last.
Also note that rearrange-style operations, such as composite axes like "(batch chan) out" or singleton axes "()", are not currently supported.
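Because composite axes cannot appear in the pattern, a merged axis has to be split before the contraction. The sketch below does this with a plain NumPy reshape, assuming the merged axis uses row-major grouping (batch outer, chan inner, which is how einops reads "(batch chan)"); with einops itself, `rearrange(merged, "(batch chan) out -> batch chan out", chan=chan)` performs the same split. The contraction is written with `np.einsum` so the snippet runs with NumPy alone, and the sizes are arbitrary.

```python
import numpy as np

batch, chan, out = 4, 3, 5
rng = np.random.default_rng(0)

# A tensor whose first axis is really a merged "(batch chan)" axis.
merged = rng.standard_normal((batch * chan, out))
weights = rng.standard_normal((chan, out))

# "(batch chan) out" cannot be used inside an einsum pattern, so split the
# merged axis first (row-major: batch is the outer, slower-varying axis).
split = merged.reshape(batch, chan, out)

# Then contract with ordinary, ungrouped axes. The einops-style pattern would be
# "batch chan out, chan out -> batch"; np.einsum needs single-letter subscripts.
result = np.einsum("bco,co->b", split, weights)

print(result.shape)  # (4,)
```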
Examples:
For a given pattern such as:
>>> import numpy as np
>>> from einops import einsum
>>> x, y, z = np.random.randn(3, 20, 20, 20)
>>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k")
the following formula is computed:
$$\mathrm{output}[a, b, k] = \sum_{c, d, g} x[a, b, c] \cdot y[c, b, d] \cdot z[a, g, k]$$
where c, d, and g are summed over because those axis names do not appear on the right-hand side of the pattern.
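To make the formula concrete, the following NumPy-only sketch evaluates the same contraction twice: once with `np.einsum`, whose single-letter subscripts "abc,cbd,agk->abk" are assumed here to be the equivalent of the einops pattern (note that `np.einsum` takes the subscripts first, the reverse of einops.einsum), and once with explicit loops summing over c, d, and g. Small 4×4×4 tensors keep the loops fast.

```python
import numpy as np

rng = np.random.default_rng(0)
x, y, z = rng.standard_normal((3, 4, 4, 4))

# np.einsum form of "a b c, c b d, a g k -> a b k" (subscripts come first here).
via_einsum = np.einsum("abc,cbd,agk->abk", x, y, z)

# Explicit evaluation: sum over c, d and g, the axes missing from the output.
A, B, C = x.shape
D = y.shape[2]
G, K = z.shape[1], z.shape[2]
explicit = np.zeros((A, B, K))
for a in range(A):
    for b in range(B):
        for k in range(K):
            total = 0.0
            for c in range(C):
                for d in range(D):
                    for g in range(G):
                        total += x[a, b, c] * y[c, b, d] * z[a, g, k]
            explicit[a, b, k] = total

print(np.allclose(via_einsum, explicit))  # True
```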
Let's see some additional examples:
# Filter a set of images:
>>> batched_images = np.random.randn(128, 16, 16)
>>> filters = np.random.randn(16, 16, 30)
>>> result = einsum(batched_images, filters,
... "batch h w, h w channel -> batch channel")
>>> result.shape
(128, 30)
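Summing over h and w together is equivalent to flattening both axes and doing a single matrix product. The NumPy sketch below, which does not call einops, checks that equivalence for the image-filtering example, assuming "bhw,hwc->bc" as the `np.einsum` spelling of the pattern.

```python
import numpy as np

rng = np.random.default_rng(0)
batched_images = rng.standard_normal((128, 16, 16))
filters = rng.standard_normal((16, 16, 30))

# Same contraction as "batch h w, h w channel -> batch channel".
via_einsum = np.einsum("bhw,hwc->bc", batched_images, filters)

# Flatten (h, w) into one axis of length 256 on both sides, then multiply.
via_matmul = batched_images.reshape(128, 16 * 16) @ filters.reshape(16 * 16, 30)

print(via_einsum.shape)                     # (128, 30)
print(np.allclose(via_einsum, via_matmul))  # True
```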
# Matrix multiplication, with an unknown input shape:
>>> batch_shape = (50, 30)
>>> data = np.random.randn(*batch_shape, 20)
>>> weights = np.random.randn(10, 20)
>>> result = einsum(weights, data,
... "out_dim in_dim, ... in_dim -> ... out_dim")
>>> result.shape
(50, 30, 10)
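The "..." in the pattern stands for any number of leading axes, which is what lets the same weights work for inputs of unknown rank. As a NumPy-only sketch (assuming "oi,...i->...o" as the `np.einsum` equivalent), the result matches a broadcasting matrix product with the transposed weights, and the same pattern accepts a batch of a different rank unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.standard_normal((50, 30, 20))
weights = rng.standard_normal((10, 20))

# "out_dim in_dim, ... in_dim -> ... out_dim"
via_einsum = np.einsum("oi,...i->...o", weights, data)

# matmul broadcasts over the leading (50, 30) axes in the same way.
via_matmul = data @ weights.T

print(via_einsum.shape)                     # (50, 30, 10)
print(np.allclose(via_einsum, via_matmul))  # True

# A rank-2 batch works with the very same pattern.
other = np.einsum("oi,...i->...o", weights, rng.standard_normal((7, 20)))
print(other.shape)                          # (7, 10)
```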
# Matrix trace on a single tensor:
>>> matrix = np.random.randn(10, 10)
>>> result = einsum(matrix, "i i ->")
>>> result.shape
()
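In "i i ->", the repeated name makes both indices walk the diagonal together, and because i is absent from the (empty) output, the diagonal entries are summed to a scalar. This NumPy sketch, assuming "ii->" as the `np.einsum` equivalent, confirms that the result agrees with `np.trace`.

```python
import numpy as np

matrix = np.random.default_rng(0).standard_normal((10, 10))

# Repeated index selects the diagonal; the empty output sums it.
trace = np.einsum("ii->", matrix)

print(np.shape(trace))                             # ()
print(np.isclose(trace, np.trace(matrix)))         # True
print(np.isclose(trace, matrix.diagonal().sum()))  # True
```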
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tensors_and_pattern | Union[Tensor, str] | tensors: tensors of any supported library (numpy, tensorflow, pytorch, jax). pattern: string, einsum pattern, with commas separating the specifications for each tensor. The pattern must be provided after all tensors. | () |

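To illustrate the calling convention (tensors first, pattern last) and how multi-letter axis names relate to classic einsum subscripts, here is a hypothetical helper, `named_einsum`, that is not part of einops and not its implementation. It is a minimal sketch assuming NumPy arrays, space-separated axis names, an optional "...", exactly one "->", and no composite or singleton axes; it maps each distinct name to a single letter and delegates to `np.einsum`.

```python
import string

import numpy as np


def named_einsum(*tensors_and_pattern):
    """Hypothetical sketch: evaluate an einops-style pattern with np.einsum.

    Not the einops implementation. Assumes NumPy arrays, space-separated
    axis names, optional "...", and no "(a b)" or "()" axes.
    """
    *tensors, pattern = tensors_and_pattern
    # The pattern must come after all tensors, mirroring einops.einsum.
    if not isinstance(pattern, str):
        raise ValueError("the pattern must be passed after all tensors")
    if pattern.count("->") != 1:
        raise ValueError("the pattern must contain exactly one '->'")

    lhs, rhs = pattern.split("->")
    inputs = [spec.split() for spec in lhs.split(",")]
    output = rhs.split()
    if len(inputs) != len(tensors):
        raise ValueError("expected one axis specification per tensor")

    # Assign one letter per distinct axis name (at most 52 names in this sketch).
    letters = {}

    def to_subscripts(names):
        parts = []
        for name in names:
            if name == "...":
                parts.append("...")
                continue
            if name not in letters:
                letters[name] = string.ascii_letters[len(letters)]
            parts.append(letters[name])
        return "".join(parts)

    # Inputs are translated first, so every summed or kept name gets its letter;
    # an output name absent from the inputs makes np.einsum raise an error.
    subscripts = ",".join(to_subscripts(names) for names in inputs)
    subscripts += "->" + to_subscripts(output)
    return np.einsum(subscripts, *tensors)
```

For example, `named_einsum(weights, data, "out_dim in_dim, ... in_dim -> ... out_dim")` is translated to `np.einsum("ab,...b->...a", weights, data)`, while `named_einsum("i i ->", matrix)` is rejected because the pattern is not last.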
Returns:

| Type | Description |
|---|---|
| Tensor | Tensor of the same type as the input, after processing with einsum. |
