2025-08-11 12:24:21 +08:00

416 lines
13 KiB
Python

"""A functionally equivalent parser of the numpy.einsum input parser."""
import itertools
from typing import Any, Dict, Iterator, List, Sequence, Tuple
from opt_einsum.typing import ArrayType, TensorShapeType
__all__ = [
"is_valid_einsum_char",
"has_valid_einsum_chars_only",
"get_symbol",
"get_shape",
"gen_unused_symbols",
"convert_to_valid_einsum_chars",
"alpha_canonicalize",
"find_output_str",
"find_output_shape",
"possibly_convert_to_numpy",
"parse_einsum_input",
]
_einsum_symbols_base = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
def is_valid_einsum_char(x: str) -> bool:
"""Check if the character ``x`` is valid for numpy einsum.
**Examples:**
```python
is_valid_einsum_char("a")
#> True
is_valid_einsum_char("Ǵ")
#> False
```
"""
return (x in _einsum_symbols_base) or (x in ",->.")
def has_valid_einsum_chars_only(einsum_str: str) -> bool:
"""Check if ``einsum_str`` contains only valid characters for numpy einsum.
**Examples:**
```python
has_valid_einsum_chars_only("abAZ")
#> True
has_valid_einsum_chars_only("Över")
#> False
```
"""
return all(map(is_valid_einsum_char, einsum_str))
def get_symbol(i: int) -> str:
"""Get the symbol corresponding to int ``i`` - runs through the usual 52
letters before resorting to unicode characters, starting at ``chr(192)`` and skipping surrogates.
**Examples:**
```python
get_symbol(2)
#> 'c'
get_symbol(200)
#> 'Ŕ'
get_symbol(20000)
#> ''
```
"""
if i < 52:
return _einsum_symbols_base[i]
elif i >= 55296:
# Skip chr(57343) - chr(55296) as surrogates
return chr(i + 2048)
else:
return chr(i + 140)
def gen_unused_symbols(used: str, n: int) -> Iterator[str]:
"""Generate ``n`` symbols that are not already in ``used``.
**Examples:**
```python
list(oe.parser.gen_unused_symbols("abd", 2))
#> ['c', 'e']
```
"""
i = cnt = 0
while cnt < n:
s = get_symbol(i)
i += 1
if s in used:
continue
yield s
cnt += 1
def convert_to_valid_einsum_chars(einsum_str: str) -> str:
"""Convert the str ``einsum_str`` to contain only the alphabetic characters
valid for numpy einsum. If there are too many symbols, let the backend
throw an error.
Examples:
--------
>>> oe.parser.convert_to_valid_einsum_chars("Ĥěļļö")
'cbdda'
"""
symbols = sorted(set(einsum_str) - set(",->"))
replacer = {x: get_symbol(i) for i, x in enumerate(symbols)}
return "".join(replacer.get(x, x) for x in einsum_str)
def alpha_canonicalize(equation: str) -> str:
"""Alpha convert an equation in an order-independent canonical way.
Examples:
--------
>>> oe.parser.alpha_canonicalize("dcba")
'abcd'
>>> oe.parser.alpha_canonicalize("Ĥěļļö")
'abccd'
"""
rename: Dict[str, str] = {}
for name in equation:
if name in ".,->":
continue
if name not in rename:
rename[name] = get_symbol(len(rename))
return "".join(rename.get(x, x) for x in equation)
def find_output_str(subscripts: str) -> str:
"""Find the output string for the inputs ``subscripts`` under canonical einstein summation rules.
That is, repeated indices are summed over by default.
Examples:
--------
>>> oe.parser.find_output_str("ab,bc")
'ac'
>>> oe.parser.find_output_str("a,b")
'ab'
>>> oe.parser.find_output_str("a,a,b,b")
''
"""
tmp_subscripts = subscripts.replace(",", "")
return "".join(s for s in sorted(set(tmp_subscripts)) if tmp_subscripts.count(s) == 1)
def find_output_shape(inputs: List[str], shapes: List[TensorShapeType], output: str) -> TensorShapeType:
"""Find the output shape for given inputs, shapes and output string, taking
into account broadcasting.
Examples:
--------
>>> oe.parser.find_output_shape(["ab", "bc"], [(2, 3), (3, 4)], "ac")
(2, 4)
# Broadcasting is accounted for
>>> oe.parser.find_output_shape(["a", "a"], [(4, ), (1, )], "a")
(4,)
"""
return tuple(max(shape[loc] for shape, loc in zip(shapes, [x.find(c) for x in inputs]) if loc >= 0) for c in output)
_BaseTypes = (bool, int, float, complex, str, bytes)
def get_shape(x: Any) -> TensorShapeType:
"""Get the shape of the array-like object `x`. If `x` is not array-like, raise an error.
Array-like objects are those that have a `shape` attribute, are sequences of BaseTypes, or are BaseTypes.
BaseTypes are defined as `bool`, `int`, `float`, `complex`, `str`, and `bytes`.
"""
if hasattr(x, "shape"):
return x.shape
elif isinstance(x, _BaseTypes):
return ()
elif isinstance(x, Sequence):
shape = []
while isinstance(x, Sequence) and not isinstance(x, _BaseTypes):
shape.append(len(x))
x = x[0]
return tuple(shape)
else:
raise ValueError(f"Cannot determine the shape of {x}, can only determine the shape of array-like objects.")
def possibly_convert_to_numpy(x: Any) -> Any:
"""Convert things without a 'shape' to ndarrays, but leave everything else.
Examples:
--------
>>> oe.parser.possibly_convert_to_numpy(5)
array(5)
>>> oe.parser.possibly_convert_to_numpy([5, 3])
array([5, 3])
>>> oe.parser.possibly_convert_to_numpy(np.array([5, 3]))
array([5, 3])
# Any class with a shape is passed through
>>> class Shape:
... def __init__(self, shape):
... self.shape = shape
...
>>> myshape = Shape((5, 5))
>>> oe.parser.possibly_convert_to_numpy(myshape)
<__main__.Shape object at 0x10f850710>
"""
if not hasattr(x, "shape"):
try:
import numpy as np # type: ignore
except ModuleNotFoundError:
raise ModuleNotFoundError(
"numpy is required to convert non-array objects to arrays. This function will be deprecated in the future."
)
return np.asanyarray(x)
else:
return x
def convert_subscripts(old_sub: List[Any], symbol_map: Dict[Any, Any]) -> str:
"""Convert user custom subscripts list to subscript string according to `symbol_map`.
Examples:
--------
>>> oe.parser.convert_subscripts(['abc', 'def'], {'abc':'a', 'def':'b'})
'ab'
>>> oe.parser.convert_subscripts([Ellipsis, object], {object:'a'})
'...a'
"""
new_sub = ""
for s in old_sub:
if s is Ellipsis:
new_sub += "..."
else:
# no need to try/except here because symbol_map has already been checked
new_sub += symbol_map[s]
return new_sub
def convert_interleaved_input(operands: Sequence[Any]) -> Tuple[str, Tuple[Any, ...]]:
"""Convert 'interleaved' input to standard einsum input."""
tmp_operands = list(operands)
operand_list = []
subscript_list = []
for _ in range(len(operands) // 2):
operand_list.append(tmp_operands.pop(0))
subscript_list.append(tmp_operands.pop(0))
output_list = tmp_operands[-1] if len(tmp_operands) else None
# build a map from user symbols to single-character symbols based on `get_symbol`
# The map retains the intrinsic order of user symbols
try:
# collect all user symbols
symbol_set = set(itertools.chain.from_iterable(subscript_list))
# remove Ellipsis because it can not be compared with other objects
symbol_set.discard(Ellipsis)
# build the map based on sorted user symbols, retaining the order we lost in the `set`
symbol_map = {symbol: get_symbol(idx) for idx, symbol in enumerate(sorted(symbol_set))}
except TypeError: # unhashable or uncomparable object
raise TypeError(
"For this input type lists must contain either Ellipsis "
"or hashable and comparable object (e.g. int, str)."
)
subscripts = ",".join(convert_subscripts(sub, symbol_map) for sub in subscript_list)
if output_list is not None:
subscripts += "->"
subscripts += convert_subscripts(output_list, symbol_map)
return subscripts, tuple(operand_list)
def parse_einsum_input(operands: Any, shapes: bool = False) -> Tuple[str, str, List[ArrayType]]:
"""A reproduction of einsum c side einsum parsing in python.
Parameters:
operands: Intakes the same inputs as `contract_path`, but NOT the keyword args. The only
supported keyword argument is:
shapes: Whether ``parse_einsum_input`` should assume arrays (the default) or
array shapes have been supplied.
Returns:
input_strings: Parsed input strings
output_string: Parsed output string
operands: The operands to use in the numpy contraction
Examples:
The operand list is simplified to reduce printing:
```python
>>> a = np.random.rand(4, 4)
>>> b = np.random.rand(4, 4, 4)
>>> parse_einsum_input(('...a,...a->...', a, b))
('za,xza', 'xz', [a, b])
>>> parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
('za,xza', 'xz', [a, b])
```
"""
if len(operands) == 0:
raise ValueError("No input operands")
if isinstance(operands[0], str):
subscripts = operands[0].replace(" ", "")
if shapes:
if any(hasattr(o, "shape") for o in operands[1:]):
raise ValueError(
"shapes is set to True but given at least one operand looks like an array"
" (at least one operand has a shape attribute). "
)
operands = operands[1:]
else:
subscripts, operands = convert_interleaved_input(operands)
if shapes:
operand_shapes = operands
else:
operand_shapes = [get_shape(o) for o in operands]
# Check for proper "->"
if ("-" in subscripts) or (">" in subscripts):
invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
if invalid or (subscripts.count("->") != 1):
raise ValueError("Subscripts can only contain one '->'.")
# Parse ellipses
if "." in subscripts:
used = subscripts.replace(".", "").replace(",", "").replace("->", "")
ellipse_inds = "".join(gen_unused_symbols(used, max(len(x) for x in operand_shapes)))
longest = 0
# Do we have an output to account for?
if "->" in subscripts:
input_tmp, output_sub = subscripts.split("->")
split_subscripts = input_tmp.split(",")
out_sub = True
else:
split_subscripts = subscripts.split(",")
out_sub = False
for num, sub in enumerate(split_subscripts):
if "." in sub:
if (sub.count(".") != 3) or (sub.count("...") != 1):
raise ValueError("Invalid Ellipses.")
# Take into account numerical values
if operand_shapes[num] == ():
ellipse_count = 0
else:
ellipse_count = max(len(operand_shapes[num]), 1) - (len(sub) - 3)
if ellipse_count > longest:
longest = ellipse_count
if ellipse_count < 0:
raise ValueError("Ellipses lengths do not match.")
elif ellipse_count == 0:
split_subscripts[num] = sub.replace("...", "")
else:
split_subscripts[num] = sub.replace("...", ellipse_inds[-ellipse_count:])
subscripts = ",".join(split_subscripts)
# Figure out output ellipses
if longest == 0:
out_ellipse = ""
else:
out_ellipse = ellipse_inds[-longest:]
if out_sub:
subscripts += "->" + output_sub.replace("...", out_ellipse)
else:
# Special care for outputless ellipses
output_subscript = find_output_str(subscripts)
normal_inds = "".join(sorted(set(output_subscript) - set(out_ellipse)))
subscripts += "->" + out_ellipse + normal_inds
# Build output string if does not exist
if "->" in subscripts:
input_subscripts, output_subscript = subscripts.split("->")
else:
input_subscripts, output_subscript = subscripts, find_output_str(subscripts)
# Make sure output subscripts are unique and in the input
for char in output_subscript:
if output_subscript.count(char) != 1:
raise ValueError(f"Output character '{char}' appeared more than once in the output.")
if char not in input_subscripts:
raise ValueError(f"Output character '{char}' did not appear in the input")
# Make sure number operands is equivalent to the number of terms
if len(input_subscripts.split(",")) != len(operands):
raise ValueError(
f"Number of einsum subscripts, {len(input_subscripts.split(','))}, must be equal to the "
f"number of operands, {len(operands)}."
)
return input_subscripts, output_subscript, operands