einops.parse_shape
Parse a tensor shape into a dictionary mapping axis names to their lengths.
```python
# Use underscore to skip the dimension in parsing.
>>> x = np.zeros([2, 3, 5, 7])
>>> parse_shape(x, 'batch _ h w')
{'batch': 2, 'h': 5, 'w': 7}
# `parse_shape` output can be used to specify axes_lengths for other operations:
>>> y = np.zeros([700])
>>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
(2, 10, 5, 7)
```
For symbolic frameworks, this may return symbols rather than integers.
**Parameters:**

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | | tensor of any supported framework | *required* |
| `pattern` | `str` | str, space-separated names for axes; an underscore means skip that axis | *required* |
**Returns:**

| Type | Description |
|---|---|
| `dict` | dict, maps axis names to their lengths |
Source code in einops/einops.py
def parse_shape(x, pattern: str) -> dict:
    """
    Parse a tensor shape into a dictionary mapping axis names to their lengths.

    ```python
    # Use underscore to skip the dimension in parsing.
    >>> x = np.zeros([2, 3, 5, 7])
    >>> parse_shape(x, 'batch _ h w')
    {'batch': 2, 'h': 5, 'w': 7}

    # `parse_shape` output can be used to specify axes_lengths for other operations:
    >>> y = np.zeros([700])
    >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
    (2, 10, 5, 7)
    ```

    For symbolic frameworks, this may return symbols rather than integers.

    Parameters:
        x: tensor of any supported framework
        pattern: str, space separated names for axes, underscore means skip axis.
            If the pattern contains an ellipsis, it stands for any number of
            dimensions (possibly zero); those dimensions are skipped and do not
            appear in the result.

    Returns:
        dict, maps axes names to their lengths

    Raises:
        RuntimeError: if the pattern has composite axes; if the tensor's number of
            dimensions is incompatible with the pattern; if an empty axis group
            corresponds to a dimension whose length is not 1; or if an anonymous
            axis does not match the corresponding dimension's length.
    """
    exp = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)
    if exp.has_composed_axes():
        raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}")

    # Validate the number of dimensions. Without an ellipsis the counts must match
    # exactly; with one, the ellipsis may cover zero dims, so the tensor needs at
    # least len(composition) - 1 dimensions.
    if len(shape) != len(exp.composition):
        if exp.has_ellipsis:
            if len(shape) < len(exp.composition) - 1:
                raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}")
        else:
            raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}")

    if exp.has_ellipsis:
        # Replace the ellipsis with as many skipped axes as it covers, so each element
        # of `composition` lines up with exactly one entry of `shape`. Each filler is
        # the single-axis list ["_"], giving it the same structure as every other
        # element (a list of axis names) instead of a bare string.
        ellipsis_idx = exp.composition.index(_ellipsis)
        n_skipped = len(shape) - len(exp.composition) + 1
        composition = (
            exp.composition[:ellipsis_idx]
            + [["_"]] * n_skipped
            + exp.composition[ellipsis_idx + 1 :]
        )
    else:
        composition = exp.composition

    result = {}
    for axes, axis_length in zip(composition, shape):  # type: ignore
        # axes is either [], or [AnonymousAxis], or ['axis_name']
        if len(axes) == 0:
            # An empty group consumes one dimension, which must have length 1.
            if axis_length != 1:
                raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}")
            continue
        [axis] = axes  # composite axes were rejected above, so exactly one axis here
        if isinstance(axis, str):
            if axis != "_":
                result[axis] = axis_length
        elif axis.value != axis_length:
            raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}")
    return result