2025-08-11 12:24:21 +08:00

535 lines
19 KiB
Python

"""
Tests the accuracy of the opt_einsum paths in addition to unit tests for
the various path helper functions.
"""
import itertools
from concurrent.futures import ProcessPoolExecutor
from typing import Any, Dict, List, Optional
import pytest
import opt_einsum as oe
from opt_einsum.testing import build_shapes, rand_equation
from opt_einsum.typing import ArrayIndexType, OptimizeKind, PathType, TensorShapeType
# Hand-built optimizer cases keyed by name. Each value is a 3-tuple that
# `assert_contract_order` forwards positionally to a path function:
# (list of per-operand index sets, output index set, index -> dimension size).
explicit_path_tests = {
    "GEMM1": (
        [{"a", "b", "d"}, {"a", "c"}, {"b", "d", "c"}],
        set(),
        dict(a=1, b=2, c=3, d=4),
    ),
    "Inner1": (
        [{"a", "b", "c", "d"}, {"a", "b", "c"}, {"b", "c"}],
        set(),
        dict(a=5, b=2, c=3, d=4),
    ),
}
# Rows are [algorithm, equation, expected path] and feed `test_path_edge_cases`.
# Note that these tests have no unique solution due to the chosen dimensions.
# Each (algorithm, equation) pair appears exactly once.
path_edge_tests = [
    ["greedy", "eb,cb,fb->cef", ((0, 2), (0, 1))],
    ["branch-all", "eb,cb,fb->cef", ((0, 2), (0, 1))],
    ["branch-2", "eb,cb,fb->cef", ((0, 2), (0, 1))],
    ["optimal", "eb,cb,fb->cef", ((0, 2), (0, 1))],
    ["dp", "eb,cb,fb->cef", ((1, 2), (0, 1))],
    ["greedy", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
    ["branch-all", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
    ["branch-2", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
    ["optimal", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
    ["dp", "dd,fb,be,cdb->cef", ((0, 3), (0, 2), (0, 1))],
    ["greedy", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
    ["branch-all", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
    ["branch-2", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
    ["optimal", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
    ["dp", "bca,cdb,dbf,afc->", ((1, 2), (1, 2), (0, 1))],
    ["greedy", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 1), (0, 1))],
    ["branch-all", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
    ["branch-2", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
    ["optimal", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
    ["dp", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
]
# Equations that include scalar (empty-subscript) operands.
# Note that these tests have no unique solution due to the chosen dimensions;
# the integer in each row is the number of steps the returned path must have.
path_scalar_tests = [
    ["a,->a", 1],
    ["ab,->ab", 1],
    [",a,->a", 2],
    [",,a,->a", 3],
    [",,->", 2],
]
def check_path(test_output: PathType, benchmark: PathType, bypass: bool = False) -> bool:
    """Return True when ``test_output`` reproduces ``benchmark`` step for step.

    Args:
        test_output: Path returned by an optimizer. It must be a ``list`` and
            every entry must be a ``tuple``; anything else yields ``False``.
        benchmark: Expected contraction steps, as any sized iterable.
        bypass: Unused; retained so existing call sites keep working.

    Returns:
        ``True`` only if both paths have the same length and every step of
        ``test_output`` is a tuple equal to the corresponding benchmark step.
    """
    if not isinstance(test_output, list):
        return False
    if len(test_output) != len(benchmark):
        return False
    # Walk both paths in lockstep instead of indexing a fresh copy of the
    # benchmark on every iteration.
    return all(
        isinstance(step, tuple) and step == expected
        for step, expected in zip(test_output, benchmark)
    )
def assert_contract_order(func: Any, test_data: Any, max_size: int, benchmark: PathType) -> None:
    """Run ``func`` on an explicit test case and assert it yields ``benchmark``.

    The first three items of ``test_data`` become the first three positional
    arguments of ``func``; ``max_size`` is forwarded as the fourth.
    """
    first, second, third = test_data[0], test_data[1], test_data[2]
    produced = func(first, second, third, max_size)
    assert check_path(produced, benchmark)
def test_size_by_dict() -> None:
    """Check ``compute_size_by_dict`` for empty, zero-sized and repeated indices."""
    sizes_dict = dict(zip("abcdez", [2, 5, 9, 11, 13, 0]))
    path_func = oe.helpers.compute_size_by_dict

    # (index string, expected product of the sizes of its indices)
    expectations = [
        ("", 1),
        ("a", 2),
        ("b", 5),
        ("z", 0),
        ("az", 0),
        ("zbc", 0),
        ("aaae", 104),
        ("abcde", 12870),
    ]
    for indices, expected in expectations:
        assert expected == path_func(indices, sizes_dict)
def test_flop_cost() -> None:
    """Check ``flop_count`` for a handful of basic contraction shapes."""
    size_dict = dict.fromkeys("abcdef", 10)
    flop_count = oe.helpers.flop_count

    # Each row: (expected flops, then the first three arguments of flop_count)
    cases = [
        # Loop over an array
        (10, "a", False, 1),
        # Hadamard product (*)
        (10, "a", False, 2),
        (100, "ab", False, 2),
        # Inner product (+, *)
        (20, "a", True, 2),
        (200, "ab", True, 2),
        # Inner product x3 (+, *, *)
        (30, "a", True, 3),
        # GEMM
        (2000, "abc", True, 2),
    ]
    for expected, indices, second_arg, third_arg in cases:
        assert expected == flop_count(indices, second_arg, third_arg, size_dict)
def test_bad_path_option() -> None:
    """A misspelled ``optimize`` name is expected to raise ``KeyError``."""
    operands = ([1], [2], [3])
    with pytest.raises(KeyError):
        oe.contract("a,b,c", *operands, optimize="optimall", shapes=True)  # type: ignore
def test_explicit_path() -> None:
    """Contract three one-element operands along a caller-supplied path."""
    pytest.importorskip("numpy")

    explicit_path = [(1, 2), (0, 1)]
    result = oe.contract("a,b,c", [1], [2], [3], optimize=explicit_path)
    assert result.item() == 6
def test_path_optimal() -> None:
    """Check ``oe.paths.optimal`` on GEMM1 with a size cap of 5000 and of 0."""
    gemm_case = explicit_path_tests["GEMM1"]
    path_fn = oe.paths.optimal

    # (max_size, expected path)
    for max_size, expected in ((5000, [(0, 2), (0, 1)]), (0, [(0, 1, 2)])):
        assert_contract_order(path_fn, gemm_case, max_size, expected)
def test_path_greedy() -> None:
    """Check ``oe.paths.greedy`` on GEMM1 with a size cap of 5000 and of 0."""
    gemm_case = explicit_path_tests["GEMM1"]
    path_fn = oe.paths.greedy

    # (max_size, expected path)
    for max_size, expected in ((5000, [(0, 2), (0, 1)]), (0, [(0, 1, 2)])):
        assert_contract_order(path_fn, gemm_case, max_size, expected)
def test_memory_paths() -> None:
    """Compare ``optimal`` and ``greedy`` paths for ``memory_limit`` 5 and -1."""
    expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"
    views = build_shapes(expression)

    single_step = [(0, 1, 2, 3, 4, 5)]
    pairwise = [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)]

    # (memory_limit, algorithm, expected path), in the order they are checked
    scenarios = [
        # Tiny memory limit: everything is contracted in one step
        (5, "optimal", single_step),
        (5, "greedy", single_step),
        # memory_limit=-1: both algorithms produce the same pairwise path
        (-1, "optimal", pairwise),
        (-1, "greedy", pairwise),
    ]
    for limit, algorithm, expected in scenarios:
        path_ret = oe.contract_path(expression, *views, optimize=algorithm, memory_limit=limit, shapes=True)
        assert check_path(path_ret[0], expected)
@pytest.mark.parametrize("alg,expression,order", path_edge_tests)
def test_path_edge_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
    """Each algorithm must return the recorded path for its edge-case equation."""
    operand_shapes = build_shapes(expression)
    found_path, _ = oe.contract_path(expression, *operand_shapes, optimize=alg, shapes=True)
    assert check_path(found_path, order)
@pytest.mark.parametrize("expression,order", path_scalar_tests)
@pytest.mark.parametrize("alg", oe.paths._PATH_OPTIONS)
def test_path_scalar_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
    """Every registered algorithm must return a path with ``order`` steps.

    Here ``order`` is the expected number of steps, not the path itself.
    """
    operand_shapes = build_shapes(expression)
    found_path, _ = oe.contract_path(expression, *operand_shapes, optimize=alg, shapes=True)
    assert len(found_path) == order
def test_optimal_edge_cases() -> None:
    """With ``memory_limit="max_input"`` greedy and optimal give the same two-step path."""
    expression = "a,ac,ab,ad,cd,bd,bc->"
    operand_shapes = build_shapes(expression, dimension_dict=dict.fromkeys("abcd", 20))
    expected = [(0, 1), (0, 1, 2, 3, 4, 5)]

    for algorithm in ("greedy", "optimal"):
        path, _ = oe.contract_path(
            expression, *operand_shapes, optimize=algorithm, memory_limit="max_input", shapes=True
        )
        assert check_path(path, expected)
def test_greedy_edge_cases() -> None:
    """Check the greedy path with ``memory_limit`` set to "max_input" and to -1."""
    expression = "abc,cfd,dbe,efa"
    dim_dict = dict.fromkeys(expression.replace(",", ""), 20)
    tensors = build_shapes(expression, dimension_dict=dim_dict)

    # (memory_limit, expected path)
    scenarios = (
        ("max_input", [(0, 1, 2, 3)]),
        (-1, [(0, 1), (0, 2), (0, 1)]),
    )
    for limit, expected in scenarios:
        path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit=limit, shapes=True)
        assert check_path(path, expected)
def test_dp_edge_cases_dimension_1() -> None:
    """``dp`` must cope with an equation in which every dimension has size 1."""
    eq = "nlp,nlq,pl->n"
    shapes = [(1,) * 3, (1,) * 3, (1,) * 2]
    _, info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")
    assert max(info.scale_list) == 3
def test_dp_edge_cases_all_singlet_indices() -> None:
    """``dp`` must cope with an equation in which no index is shared between operands."""
    eq = "a,bcd,efg->"
    shapes = [(2,), (2,) * 3, (2,) * 3]
    _, info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")
    assert max(info.scale_list) == 3
def test_custom_dp_can_optimize_for_outer_products() -> None:
    """``search_outer=True`` must find a cheaper path than ``search_outer=False``."""
    eq = "a,b,abc->c"
    da, db, dc = 2, 2, 3
    shapes = [(da,), (db,), (da, db, dc)]

    without_outer = oe.DynamicProgramming(search_outer=False)
    with_outer = oe.DynamicProgramming(search_outer=True)

    _, info_without = oe.contract_path(eq, *shapes, shapes=True, optimize=without_outer)
    _, info_with = oe.contract_path(eq, *shapes, shapes=True, optimize=with_outer)

    assert info_with.opt_cost < info_without.opt_cost
def test_custom_dp_can_optimize_for_size() -> None:
    """Minimizing flops must cost less but need a larger intermediate than minimizing size."""
    eq, shapes = rand_equation(10, 4, seed=43)

    flops_opt = oe.DynamicProgramming(minimize="flops")
    size_opt = oe.DynamicProgramming(minimize="size")

    _, by_flops = oe.contract_path(eq, *shapes, shapes=True, optimize=flops_opt)
    _, by_size = oe.contract_path(eq, *shapes, shapes=True, optimize=size_opt)

    assert by_flops.opt_cost < by_size.opt_cost
    assert by_flops.largest_intermediate > by_size.largest_intermediate
def test_custom_dp_can_set_cost_cap() -> None:
    """All three ``cost_cap`` settings (True, False, 100) must reach the same cost."""
    eq, shapes = rand_equation(5, 3, seed=42)

    optimizers = [oe.DynamicProgramming(cost_cap=cap) for cap in (True, False, 100)]
    costs = [
        oe.contract_path(eq, *shapes, shapes=True, optimize=optimizer)[1].opt_cost
        for optimizer in optimizers
    ]

    assert costs[0] == costs[1] == costs[2]
@pytest.mark.parametrize(
    "minimize,cost,width,path",
    [
        ("flops", 663054, 18900, [(4, 5), (2, 5), (2, 7), (5, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
        ("size", 1114440, 2016, [(2, 7), (3, 8), (3, 7), (2, 6), (1, 5), (1, 4), (1, 3), (1, 2), (0, 1)]),
        ("write", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
        ("combo", 973518, 2016, [(4, 5), (2, 5), (6, 7), (2, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
        ("limit", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
        ("combo-256", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
        ("limit-256", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
    ],
)
def test_custom_dp_can_set_minimize(minimize: str, cost: int, width: int, path: PathType) -> None:
    """Pin the path, cost and largest intermediate found for each ``minimize`` setting."""
    eq, shapes = rand_equation(10, 4, seed=43)
    dp_optimizer = oe.DynamicProgramming(minimize=minimize)
    _, info = oe.contract_path(eq, *shapes, shapes=True, optimize=dp_optimizer)

    assert info.path == path
    assert info.opt_cost == cost
    assert info.largest_intermediate == width
def test_dp_errors_when_no_contractions_found() -> None:
    """``dp`` must raise ``RuntimeError`` when ``memory_limit`` is below the smallest feasible size."""
    eq, shapes = rand_equation(10, 3, seed=42)

    # Find the smallest achievable largest-intermediate by minimizing size.
    size_optimizer = oe.DynamicProgramming(minimize="size")
    mincost = oe.contract_path(eq, *shapes, shapes=True, optimize=size_optimizer)[1].largest_intermediate

    # A limit of exactly that size must still succeed with the plain "dp" option.
    oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost, optimize="dp")

    # One below it leaves no valid contraction.
    too_small = mincost - 1
    with pytest.raises(RuntimeError):
        oe.contract_path(eq, *shapes, shapes=True, memory_limit=too_small, optimize="dp")
@pytest.mark.parametrize("optimize", ["greedy", "branch-2", "branch-all", "optimal", "dp"])
def test_can_optimize_outer_products(optimize: OptimizeKind) -> None:
    """Every listed algorithm must return the same path for an equation with unconnected operands."""
    square, thin = (10, 10), (10, 2)
    expected = [(2, 3), (0, 2), (0, 1)]

    found_path, _ = oe.contract_path("ab,cd,ef,fg", square, square, square, thin, optimize=optimize, shapes=True)
    assert found_path == expected
@pytest.mark.parametrize("num_symbols", [2, 3, 26, 26 + 26, 256 - 140, 300])
def test_large_path(num_symbols: int) -> None:
    """Greedy path construction must not crash on a chain of two-index operands."""
    symbols = "".join(map(oe.get_symbol, range(num_symbols)))
    dimension_dict = {symbol: size for symbol, size in zip(symbols, itertools.cycle([2, 3, 4]))}

    # Consecutive operands share one index: s0s1, s1s2, s2s3, ...
    terms = [symbols[start : start + 2] for start in range(num_symbols - 1)]
    expression = ",".join(terms)
    tensors = build_shapes(expression, dimension_dict=dimension_dict)

    # Only checks that this call completes; the resulting path is not inspected.
    oe.contract_path(expression, *tensors, optimize="greedy", shapes=True)
def test_custom_random_greedy() -> None:
    """Exercise ``oe.RandomGreedy`` across repeated runs on one expression.

    Checks that an invalid ``minimize`` value is rejected, that trial costs
    and sizes accumulate across calls, that the returned path and statistics
    agree with ``optimizer.best``, and that reusing the optimizer on a
    different expression raises ``ValueError``.
    """
    np = pytest.importorskip("numpy")

    eq, shapes = rand_equation(10, 4, seed=42)
    views = list(map(np.ones, shapes))

    with pytest.raises(ValueError):
        oe.RandomGreedy(minimize="something")

    optimizer = oe.RandomGreedy(max_repeats=10, minimize="flops")
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    assert len(optimizer.costs) == 10
    assert len(optimizer.sizes) == 10

    assert path == optimizer.path
    assert optimizer.best["flops"] == min(optimizer.costs)
    assert path_info.largest_intermediate == optimizer.best["size"]
    assert path_info.opt_cost == optimizer.best["flops"]

    # check can change settings and run again
    optimizer.temperature = 0.0
    optimizer.max_repeats = 6
    path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

    # 10 trials from the first call plus 6 from this one
    assert len(optimizer.costs) == 16
    assert len(optimizer.sizes) == 16

    assert path == optimizer.path
    # The optimizer is still minimizing flops, so the retained best trial is
    # the cheapest in flops; nothing guarantees it also has the smallest
    # intermediate, so the invariant is checked on costs rather than sizes.
    assert optimizer.best["flops"] == min(optimizer.costs)
    assert path_info.largest_intermediate == optimizer.best["size"]
    assert path_info.opt_cost == optimizer.best["flops"]

    # check error if we try and reuse the optimizer on a different expression
    eq, shapes = rand_equation(10, 4, seed=41)
    views = list(map(np.ones, shapes))
    with pytest.raises(ValueError):
        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
def test_custom_branchbound() -> None:
    """Exercise ``oe.BranchBound``: run, retune and rerun, then reject a new expression."""
    np = pytest.importorskip("numpy")

    eq, shapes = rand_equation(8, 4, seed=42)
    views = [np.ones(shape) for shape in shapes]
    optimizer = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize="size")

    def run_and_check() -> None:
        """Contract once and compare the result with the optimizer's own record."""
        found_path, info = oe.contract_path(eq, *views, optimize=optimizer)
        assert found_path == optimizer.path
        assert info.largest_intermediate == optimizer.best["size"]
        assert info.opt_cost == optimizer.best["flops"]

    run_and_check()

    # tweak settings and run again
    optimizer.nbranch = 3
    optimizer.cutoff_flops_factor = 4
    run_and_check()

    # reusing the optimizer on a different expression must fail
    other_eq, other_shapes = rand_equation(8, 4, seed=41)
    other_views = [np.ones(shape) for shape in other_shapes]
    with pytest.raises(ValueError):
        oe.contract_path(other_eq, *other_views, optimize=optimizer)
def test_branchbound_validation() -> None:
    """``BranchBound`` is expected to reject ``nbranch=0`` with ``ValueError``."""
    invalid_nbranch = 0
    with pytest.raises(ValueError):
        oe.BranchBound(nbranch=invalid_nbranch)
def test_parallel_random_greedy() -> None:
    """Exercise ``oe.RandomGreedy`` with a caller-supplied pool and with optimizer-created ones.

    The pool created here is owned by the test, so it is managed with a
    ``with`` block and its worker processes are shut down even when an
    assertion fails.
    """
    np = pytest.importorskip("numpy")

    with ProcessPoolExecutor(2) as pool:
        eq, shapes = rand_equation(10, 4, seed=42)
        views = list(map(np.ones, shapes))

        optimizer = oe.RandomGreedy(max_repeats=10, parallel=pool)
        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

        assert len(optimizer.costs) == 10
        assert len(optimizer.sizes) == 10

        assert path == optimizer.path
        # The supplied pool must be used as-is
        assert optimizer.parallel is pool
        assert optimizer._executor is pool
        assert optimizer.best["flops"] == min(optimizer.costs)
        assert path_info.largest_intermediate == optimizer.best["size"]
        assert path_info.opt_cost == optimizer.best["flops"]

        # now switch to max time algorithm
        optimizer.max_repeats = int(1e6)
        optimizer.max_time = 0.2
        optimizer.parallel = 2

        path, path_info = oe.contract_path(eq, *views, optimize=optimizer)

        # More trials than the 10 recorded by the first call
        assert len(optimizer.costs) > 10
        assert len(optimizer.sizes) > 10

        assert path == optimizer.path
        assert optimizer.best["flops"] == min(optimizer.costs)
        assert path_info.largest_intermediate == optimizer.best["size"]
        assert path_info.opt_cost == optimizer.best["flops"]

        optimizer.parallel = True
        assert optimizer._executor is not None
        assert optimizer._executor is not pool

        # No future from the timed run may still be pending
        are_done = [f.running() or f.done() for f in optimizer._futures]
        assert all(are_done)
def test_custom_path_optimizer() -> None:
    """A user-defined ``PathOptimizer`` subclass must be called and give the same result."""
    np = pytest.importorskip("numpy")

    class NaiveOptimizer(oe.paths.PathOptimizer):
        def __call__(
            self,
            inputs: List[ArrayIndexType],
            output: ArrayIndexType,
            size_dict: Dict[str, int],
            memory_limit: Optional[int] = None,
        ) -> PathType:
            # Record the call, then always contract the first two operands.
            self.was_used = True
            step_count = len(inputs) - 1
            return [(0, 1) for _ in range(step_count)]

    eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
    views = [np.ones(shape) for shape in shapes]

    reference = oe.contract(eq, *views, optimize=False)

    optimizer = NaiveOptimizer()
    result = oe.contract(eq, *views, optimize=optimizer)

    assert reference == result
    assert optimizer.was_used
def test_custom_random_optimizer() -> None:
    """Check that a user-defined ``RandomOptimizer`` subclass is used by ``oe.contract``.

    The subclass draws random contraction orders; the test compares the
    contraction result with the ``optimize=False`` result and checks that
    ``setup`` was called and that 16 trial costs were recorded.
    """
    np = pytest.importorskip("numpy")

    class NaiveRandomOptimizer(oe.path_random.RandomOptimizer):
        @staticmethod
        def random_path(
            r: int, n: int, inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int]
        ) -> Any:
            """Picks a completely random contraction order."""
            # Seed with the trial number so that each trial is reproducible.
            np.random.seed(r)
            ssa_path: List[TensorShapeType] = []
            # Ids of tensors still to be contracted; the inputs are 0..n-1.
            remaining = set(range(n))
            while len(remaining) > 1:
                i, j = np.random.choice(list(remaining), size=2, replace=False)
                # The new intermediate gets the next unused id: n plus the
                # number of contractions recorded so far.
                remaining.add(n + len(ssa_path))
                remaining.remove(i)
                remaining.remove(j)
                ssa_path.append((i, j))
            cost, size = oe.path_random.ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
            return ssa_path, cost, size

        def setup(self, inputs: Any, output: Any, size_dict: Any) -> Any:
            """Return the trial function and the arguments that follow its first parameter."""
            self.was_used = True
            n = len(inputs)
            trial_fn = self.random_path
            trial_args = (n, inputs, output, size_dict)
            return trial_fn, trial_args

    eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
    views = list(map(np.ones, shapes))

    # Reference result computed with path optimization disabled.
    exp = oe.contract(eq, *views, optimize=False)

    optimizer = NaiveRandomOptimizer(max_repeats=16)
    out = oe.contract(eq, *views, optimize=optimizer)

    assert exp == out
    assert optimizer.was_used
    # One recorded cost per repeat.
    assert len(optimizer.costs) == 16
def test_optimizer_registration() -> None:
    """Check ``oe.paths.register_path_fn`` for a taken and for a new name.

    Registering under the existing name "optimal" must raise ``KeyError``.
    A function registered under a new name must be selectable through
    ``optimize=<name>``. The temporary entry is removed in a ``finally``
    block so a failing assertion cannot leave it in the global registry.
    """

    def custom_optimizer(
        inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int]
    ) -> PathType:
        # Always contract the first two remaining operands.
        return [(0, 1)] * (len(inputs) - 1)

    with pytest.raises(KeyError):
        oe.paths.register_path_fn("optimal", custom_optimizer)

    oe.paths.register_path_fn("custom", custom_optimizer)
    try:
        assert "custom" in oe.paths._PATH_OPTIONS

        eq = "ab,bc,cd"
        shapes = [(2, 3), (3, 4), (4, 5)]
        path, _ = oe.contract_path(eq, *shapes, shapes=True, optimize="custom")  # type: ignore
        assert path == [(0, 1), (0, 1)]
    finally:
        # pop() with a default cannot raise here and so cannot mask the
        # original assertion failure.
        oe.paths._PATH_OPTIONS.pop("custom", None)
def test_path_with_assumed_shapes() -> None:
    """``contract_path`` must accept nested-list operands when ``shapes`` is not passed."""
    operands = ([[5, 3]], [[2], [4]], [[3, 2]])
    found_path, _ = oe.contract_path("ab,bc,cd", *operands)
    assert found_path == [(0, 1), (0, 1)]