Parse a tensor shape into a dictionary mapping axis names to their lengths.
```python
# Use underscore to skip the dimension in parsing.
>>> x = np.zeros([2, 3, 5, 7])
>>> parse_shape(x, 'batch _ h w')
{'batch': 2, 'h': 5, 'w': 7}
# `parse_shape` output can be used to specify axes_lengths for other operations:
>>> y = np.zeros([700])
>>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
(2, 10, 5, 7)
```
For symbolic frameworks, the returned lengths may be symbols rather than integers.
Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x` | `Tensor` | tensor of any supported framework | required |
| `pattern` | `str` | space-separated names for axes; underscore means skip axis | required |
Returns:

| Type | Description |
|------|-------------|
| `dict` | maps axis names to their lengths |
Source code in `einops/einops.py` (lines 653–712):
def parse_shape(x: Tensor, pattern: str) -> dict:
    """
    Read the dimensions of `x` and return their lengths keyed by the names in `pattern`.

    ```python
    # An underscore marks a dimension that is matched but not reported.
    >>> x = np.zeros([2, 3, 5, 7])
    >>> parse_shape(x, 'batch _ h w')
    {'batch': 2, 'h': 5, 'w': 7}

    # The returned dict can be passed as axes_lengths to other operations:
    >>> y = np.zeros([700])
    >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
    (2, 10, 5, 7)
    ```

    Pattern elements are matched one-to-one against the dimensions of `x`:
      - a name: the dimension's length is stored in the result under that name;
      - `_`: the dimension is skipped;
      - an ellipsis: stands for zero or more skipped dimensions;
      - an anonymous (numeric) axis: the dimension must have exactly that length;
      - an empty group: the dimension must have length 1.
    Composite (grouped) axes are rejected.

    With symbolic frameworks the lengths may be symbols instead of integers.

    Parameters:
        x: tensor of any supported framework
        pattern: str, space separated names for axes, underscore means skip axis

    Returns:
        dict, maps axes names to their lengths

    Raises:
        RuntimeError: if the pattern contains composite axes, if its number of
            elements is incompatible with the rank of `x`, or if an anonymous or
            empty axis does not match the actual dimension length.
    """
    parsed = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)

    if parsed.has_composed_axes():
        raise RuntimeError(f"Can't parse shape with composite axes: {pattern} {shape}")

    n_dims = len(shape)
    n_groups = len(parsed.composition)
    if n_dims != n_groups:
        if not parsed.has_ellipsis:
            raise RuntimeError(f"Can't parse shape with different number of dimensions: {pattern} {shape}")
        # The ellipsis may absorb zero dimensions, but not fewer than that.
        if n_dims < n_groups - 1:
            raise RuntimeError(f"Can't parse shape with this number of dimensions: {pattern} {shape}")

    if parsed.has_ellipsis:
        # Replace the ellipsis with one skipped ("_") axis per dimension it covers,
        # so that the pattern lines up one-to-one with `shape`.
        idx = parsed.composition.index(_ellipsis)
        n_skipped = n_dims - n_groups + 1
        head = parsed.composition[:idx]
        tail = parsed.composition[idx + 1 :]
        composition = head + [["_"]] * n_skipped + tail
    else:
        composition = parsed.composition

    names_to_lengths = {}
    for group, length in zip(composition, shape):  # type: ignore
        # each group is [] (empty), [AnonymousAxis] or ['axis_name']
        if len(group) == 0:
            if length != 1:
                raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}")
            continue
        (axis,) = group
        if not isinstance(axis, str):
            # anonymous axis: only validated, never reported
            if axis.value != length:
                raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}")
        elif axis != "_":
            names_to_lengths[axis] = length
    return names_to_lengths
|