einops.reduce
einops.reduce combines reordering and reduction in a single call, using reader-friendly notation.
Examples for reduce operation:
>>> x = np.random.randn(100, 32, 64)
# perform max-reduction on the first axis
>>> y = reduce(x, 't b c -> b c', 'max')
# same as previous, but with clearer axes meaning
>>> y = reduce(x, 'time batch channel -> batch channel', 'max')
>>> x = np.random.randn(10, 20, 30, 40)
# 2d max-pooling with kernel size = 2 * 2 for image processing
>>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
# if one wants to go back to the original height and width, depth-to-space trick can be applied
>>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
>>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')
# Adaptive 2d max-pooling to 3 * 4 grid
>>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
(10, 20, 3, 4)
# Global average pooling
>>> reduce(x, 'b c h w -> b c', 'mean').shape
(10, 20)
# Subtracting mean over batch for each channel
>>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')
# Subtracting per-image mean for each channel
>>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')
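To make the patterns above more concrete, here is a minimal sketch of what three of them compute, written in plain NumPy. It is not how einops implements them internally. It only shows the equivalent reshape and reduce steps, and the variable names (`pooled`, `global_avg`, `centered`) are chosen for this illustration.

```python
import numpy as np

x = np.random.randn(10, 20, 30, 40)
b, c, h, w = x.shape

# 'b c (h1 h2) (w1 w2) -> b c h1 w1' with 'max', h2=2, w2=2:
# split height and width into (block index, offset within block),
# then take the maximum over the two offset axes.
pooled = x.reshape(b, c, h // 2, 2, w // 2, 2).max(axis=(3, 5))

# 'b c h w -> b c' with 'mean':
# every axis missing from the right-hand side is reduced.
global_avg = x.mean(axis=(2, 3))

# 'b c h w -> b c () ()' with 'mean':
# same reduction, but unit axes are kept so the result broadcasts against x.
centered = x - x.mean(axis=(2, 3), keepdims=True)

print(pooled.shape, global_avg.shape, centered.shape)
```

In a composite axis such as `(h1 h2)`, the first name is the outer (slower-changing) index, which is why the reshape places the block index before the within-block offset.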
Parameters:
Name | Type | Description | Default |
---|---|---|---|
tensor | Union[~Tensor, List[~Tensor]] | tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch). A list of tensors is also accepted; these should be of the same type and shape. | required |
pattern | str | string, reduction pattern | required |
reduction | Union[str, Callable[[~Tensor, Tuple[int, ...]], ~Tensor]] | one of the available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive. Alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided. This allows using various reductions, e.g. np.max, tf.reduce_logsumexp, torch.var. | required |
axes_lengths | int | any additional specifications for dimensions | {} |
Returns:
Type | Description |
---|---|
~Tensor | tensor of the same type as input |
Source code in einops/einops.py
def reduce(tensor: Union[Tensor, List[Tensor]], pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
    """
    einops.reduce provides combination of reordering and reduction using reader-friendly notation.

    Examples for reduce operation:

    ```python
    >>> x = np.random.randn(100, 32, 64)

    # perform max-reduction on the first axis
    >>> y = reduce(x, 't b c -> b c', 'max')

    # same as previous, but with clearer axes meaning
    >>> y = reduce(x, 'time batch channel -> batch channel', 'max')

    >>> x = np.random.randn(10, 20, 30, 40)

    # 2d max-pooling with kernel size = 2 * 2 for image processing
    >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)

    # if one wants to go back to the original height and width, depth-to-space trick can be applied
    >>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
    >>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')

    # Adaptive 2d max-pooling to 3 * 4 grid
    >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
    (10, 20, 3, 4)

    # Global average pooling
    >>> reduce(x, 'b c h w -> b c', 'mean').shape
    (10, 20)

    # Subtracting mean over batch for each channel
    >>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')

    # Subtracting per-image mean for each channel
    >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')
    ```

    Parameters:
        tensor: tensor: tensor of any supported library (e.g. numpy.ndarray, tensorflow, pytorch).
            list of tensors is also accepted, those should be of the same type and shape
        pattern: string, reduction pattern
        reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
            alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
            This allows using various reductions, examples: np.max, tf.reduce_logsumexp, torch.var, etc.
        axes_lengths: any additional specifications for dimensions

    Returns:
        tensor of the same type as input
    """
    try:
        if isinstance(tensor, list):
            if len(tensor) == 0:
                raise TypeError("Rearrange/Reduce/Repeat can't be applied to an empty list")
            backend = get_backend(tensor[0])
            tensor = backend.stack_on_zeroth_dimension(tensor)
        else:
            backend = get_backend(tensor)

        hashable_axes_lengths = tuple(axes_lengths.items())
        shape = backend.shape(tensor)
        recipe = _prepare_transformation_recipe(pattern, reduction, axes_names=tuple(axes_lengths), ndim=len(shape))
        return _apply_recipe(
            backend, recipe, cast(Tensor, tensor), reduction_type=reduction, axes_lengths=hashable_axes_lengths
        )
    except EinopsError as e:
        message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
        if not isinstance(tensor, list):
            message += "\n Input tensor shape: {}. ".format(shape)
        else:
            message += "\n Input is list. "
        message += "Additional info: {}.".format(axes_lengths)
        raise EinopsError(message + "\n {}".format(e))