einops.einsum
einops.einsum performs einsum operations with einops-style named-axis indexing, computing tensor products over an arbitrary number of tensors. Unlike typical einsum syntax, the tensors are passed first and the pattern last.
Also, note that rearrange operations such as "(batch chan) out", or singleton axes (), are not currently supported.
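Because compositions and singleton axes are rejected, one workaround is to give every axis its own name before contracting. The sketch below assumes a hypothetical tensor laid out as "(batch chan) out" and an image with a leading singleton axis. It uses plain NumPy reshape and squeeze, with numpy.einsum (single-letter subscripts) standing in for einops.rearrange followed by einops.einsum.

```python
import numpy as np

batch, chan, out = 4, 3, 5

# Hypothetical tensor whose first dimension merges two logical axes,
# i.e. the einops-style layout "(batch chan) out".
merged = np.random.randn(batch * chan, out)
per_chan = np.random.randn(chan, 7)

# Split the composed axis first. NumPy's default row-major order matches
# einops, where the leftmost axis in "(batch chan)" varies slowest.
split = merged.reshape(batch, chan, out)

# Every axis now has its own name, so a pattern such as
# "batch chan out, chan v -> batch out v" is allowed.
result = np.einsum("bco,cv->bov", split, per_chan)
print(result.shape)  # (4, 5, 7)

# A singleton axis, as in "() h w", can simply be dropped beforehand.
image = np.random.randn(1, 16, 16)
image_hw = np.squeeze(image, axis=0)
print(image_hw.shape)  # (16, 16)
```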
Examples:
For a given pattern such as:
>>> import numpy as np
>>> from einops import einsum
>>> x, y, z = np.random.randn(3, 20, 20, 20)
>>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k")
the following formula is computed:
$$\mathrm{output}[a, b, k] = \sum_{c, d, g} x[a, b, c] \cdot y[c, b, d] \cdot z[a, g, k]$$
where the summation over c, d, and g is performed because those axis names do not appear on the right-hand side.
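To make the formula concrete, the sketch below evaluates it with explicit loops on small arrays and compares the result with numpy.einsum, used here as a stand-in for einops.einsum with each axis name kept as a single letter. It assumes only NumPy.

```python
import numpy as np

# Small sizes keep the explicit loops fast; every axis has length 3.
n = 3
x, y, z = np.random.randn(3, n, n, n)

# NumPy equivalent of the einops pattern "a b c, c b d, a g k -> a b k".
fast = np.einsum("abc,cbd,agk->abk", x, y, z)

# Direct translation of the formula: sum over every axis (c, d, g)
# that is absent from the output.
slow = np.zeros((n, n, n))
for a in range(n):
    for b in range(n):
        for k in range(n):
            for c in range(n):
                for d in range(n):
                    for g in range(n):
                        slow[a, b, k] += x[a, b, c] * y[c, b, d] * z[a, g, k]

print(np.allclose(fast, slow))  # True
```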
Let's see some additional examples:
# Filter a set of images:
>>> batched_images = np.random.randn(128, 16, 16)
>>> filters = np.random.randn(16, 16, 30)
>>> result = einsum(batched_images, filters,
... "batch h w, h w channel -> batch channel")
>>> result.shape
(128, 30)
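The filtering example contracts h and w at the same time. As a cross-check in plain NumPy (numpy.einsum standing in for einops.einsum, with b, h, w, c abbreviating batch, h, w, channel), the same result comes from flattening (h, w) into one axis and taking an ordinary matrix product:

```python
import numpy as np

batched_images = np.random.randn(128, 16, 16)
filters = np.random.randn(16, 16, 30)

# "batch h w, h w channel -> batch channel" sums over both h and w.
via_einsum = np.einsum("bhw,hwc->bc", batched_images, filters)

# Equivalent: merge (h, w) into a single axis of length 256 on both sides,
# then contract it with a matrix product.
via_matmul = batched_images.reshape(128, 16 * 16) @ filters.reshape(16 * 16, 30)

print(via_einsum.shape, np.allclose(via_einsum, via_matmul))  # (128, 30) True
```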
# Matrix multiplication, with an unknown input shape:
>>> batch_shape = (50, 30)
>>> data = np.random.randn(*batch_shape, 20)
>>> weights = np.random.randn(10, 20)
>>> result = einsum(weights, data,
... "out_dim in_dim, ... in_dim -> ... out_dim")
>>> result.shape
(50, 30, 10)
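In the matrix-multiplication example, ... stands for whatever leading axes data happens to have. A NumPy sketch of the same contraction (o and i abbreviate out_dim and in_dim) shows that it agrees with data @ weights.T and accepts a different batch shape without changes:

```python
import numpy as np

batch_shape = (50, 30)
data = np.random.randn(*batch_shape, 20)
weights = np.random.randn(10, 20)

# "out_dim in_dim, ... in_dim -> ... out_dim": the ellipsis covers any
# number of leading batch axes, which pass through unchanged.
via_einsum = np.einsum("oi,...i->...o", weights, data)

# matmul broadcasts over the leading axes in the same way.
via_matmul = data @ weights.T
print(via_einsum.shape, np.allclose(via_einsum, via_matmul))  # (50, 30, 10) True

# The same pattern works for a different number of leading axes.
other = np.random.randn(7, 20)
print(np.einsum("oi,...i->...o", weights, other).shape)  # (7, 10)
```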
# Matrix trace on a single tensor:
>>> matrix = np.random.randn(10, 10)
>>> result = einsum(matrix, "i i ->")
>>> result.shape
()
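For the trace, repeating i within a single input selects the diagonal, and the empty right-hand side sums it, leaving a 0-dimensional result. A quick NumPy cross-check of the equivalent pattern "ii->" against numpy.trace:

```python
import numpy as np

matrix = np.random.randn(10, 10)

# Repeated index on one operand -> diagonal; no output axes -> full sum.
trace = np.einsum("ii->", matrix)
print(np.shape(trace), np.isclose(trace, np.trace(matrix)))  # () True
```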
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tensors_and_pattern | Union[Tensor, str] | tensors: tensors of any supported library (numpy, tensorflow, pytorch, jax). pattern: string, einsum pattern, with commas separating specifications for each tensor. pattern should be provided after all tensors. | () |
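To illustrate the argument order and what named axes add over classic einsum subscripts, here is a minimal sketch of a hypothetical wrapper, named_einsum, that takes tensors first and the pattern last, maps each space-separated axis name to a single letter, and delegates to numpy.einsum. It is an illustration under stated assumptions, not how einops implements einsum: it supports only NumPy arrays, at most 52 distinct axis names, and "..." written as a separate token.

```python
import string

import numpy as np


def named_einsum(*tensors_and_pattern):
    """Hypothetical helper, NOT einops' implementation.

    Takes tensors first and an einops-style pattern last, maps each
    space-separated axis name to a single letter, and calls numpy.einsum.
    """
    *tensors, pattern = tensors_and_pattern
    if not isinstance(pattern, str):
        raise ValueError("the pattern must be passed after all tensors")

    lhs, rhs = pattern.split("->")
    input_specs = [spec.split() for spec in lhs.split(",")]
    if len(input_specs) != len(tensors):
        raise ValueError(
            f"pattern describes {len(input_specs)} tensors, got {len(tensors)}"
        )

    letters = iter(string.ascii_letters)  # up to 52 distinct axis names
    name_to_letter = {}

    def translate(names):
        parts = []
        for name in names:
            if name == "...":
                parts.append("...")  # ellipsis is passed through unchanged
            else:
                if name not in name_to_letter:
                    name_to_letter[name] = next(letters)
                parts.append(name_to_letter[name])
        return "".join(parts)

    subscripts = ",".join(translate(spec) for spec in input_specs)
    subscripts += "->" + translate(rhs.split())
    return np.einsum(subscripts, *tensors)


# Usage: descriptive axis names instead of single letters.
x = np.random.randn(2, 3)
w = np.random.randn(3, 4)
print(named_einsum(x, w, "batch in_dim, in_dim out_dim -> batch out_dim").shape)  # (2, 4)
```

Mapping names to letters on first use keeps repeated names (such as in_dim above, or i in "i i ->") bound to the same subscript, which is what makes them contract.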
Returns:
Type | Description |
---|---|
Tensor | Tensor of the same type as input, after processing with einsum. |
Source code in einops/einops.py