535 lines
19 KiB
Python
535 lines
19 KiB
Python
"""
|
|
Tests the accuracy of the opt_einsum paths in addition to unit tests for
|
|
the various path helper functions.
|
|
"""
|
|
|
|
import itertools
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import pytest
|
|
|
|
import opt_einsum as oe
|
|
from opt_einsum.testing import build_shapes, rand_equation
|
|
from opt_einsum.typing import ArrayIndexType, OptimizeKind, PathType, TensorShapeType
|
|
|
|
explicit_path_tests = {
|
|
"GEMM1": (
|
|
[set("abd"), set("ac"), set("bdc")],
|
|
set(""),
|
|
{"a": 1, "b": 2, "c": 3, "d": 4},
|
|
),
|
|
"Inner1": (
|
|
[set("abcd"), set("abc"), set("bc")],
|
|
set(""),
|
|
{"a": 5, "b": 2, "c": 3, "d": 4},
|
|
),
|
|
}
|
|
|
|
# note that these tests have no unique solution due to the chosen dimensions
|
|
path_edge_tests = [
|
|
["greedy", "eb,cb,fb->cef", ((0, 2), (0, 1))],
|
|
["branch-all", "eb,cb,fb->cef", ((0, 2), (0, 1))],
|
|
["branch-2", "eb,cb,fb->cef", ((0, 2), (0, 1))],
|
|
["optimal", "eb,cb,fb->cef", ((0, 2), (0, 1))],
|
|
["dp", "eb,cb,fb->cef", ((1, 2), (0, 1))],
|
|
["greedy", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
|
|
["branch-all", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
|
|
["branch-2", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
|
|
["optimal", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
|
|
["optimal", "dd,fb,be,cdb->cef", ((0, 3), (0, 1), (0, 1))],
|
|
["dp", "dd,fb,be,cdb->cef", ((0, 3), (0, 2), (0, 1))],
|
|
["greedy", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
|
|
["branch-all", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
|
|
["branch-2", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
|
|
["optimal", "bca,cdb,dbf,afc->", ((1, 2), (0, 2), (0, 1))],
|
|
["dp", "bca,cdb,dbf,afc->", ((1, 2), (1, 2), (0, 1))],
|
|
["greedy", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 1), (0, 1))],
|
|
["branch-all", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
|
|
["branch-2", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
|
|
["optimal", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
|
|
["dp", "dcc,fce,ea,dbf->ab", ((1, 2), (0, 2), (0, 1))],
|
|
]
|
|
|
|
# note that these tests have no unique solution due to the chosen dimensions
|
|
path_scalar_tests = [
|
|
[
|
|
"a,->a",
|
|
1,
|
|
],
|
|
["ab,->ab", 1],
|
|
[",a,->a", 2],
|
|
[",,a,->a", 3],
|
|
[",,->", 2],
|
|
]
|
|
|
|
|
|
def check_path(test_output: PathType, benchmark: PathType, bypass: bool = False) -> bool:
|
|
if not isinstance(test_output, list):
|
|
return False
|
|
|
|
if len(test_output) != len(benchmark):
|
|
return False
|
|
|
|
ret = True
|
|
for pos in range(len(test_output)):
|
|
ret &= isinstance(test_output[pos], tuple)
|
|
ret &= test_output[pos] == list(benchmark)[pos]
|
|
return ret
|
|
|
|
|
|
def assert_contract_order(func: Any, test_data: Any, max_size: int, benchmark: PathType) -> None:
|
|
test_output = func(test_data[0], test_data[1], test_data[2], max_size)
|
|
assert check_path(test_output, benchmark)
|
|
|
|
|
|
def test_size_by_dict() -> None:
|
|
sizes_dict = {}
|
|
for ind, val in zip("abcdez", [2, 5, 9, 11, 13, 0]):
|
|
sizes_dict[ind] = val
|
|
|
|
path_func = oe.helpers.compute_size_by_dict
|
|
|
|
assert 1 == path_func("", sizes_dict)
|
|
assert 2 == path_func("a", sizes_dict)
|
|
assert 5 == path_func("b", sizes_dict)
|
|
|
|
assert 0 == path_func("z", sizes_dict)
|
|
assert 0 == path_func("az", sizes_dict)
|
|
assert 0 == path_func("zbc", sizes_dict)
|
|
|
|
assert 104 == path_func("aaae", sizes_dict)
|
|
assert 12870 == path_func("abcde", sizes_dict)
|
|
|
|
|
|
def test_flop_cost() -> None:
|
|
size_dict = {v: 10 for v in "abcdef"}
|
|
|
|
# Loop over an array
|
|
assert 10 == oe.helpers.flop_count("a", False, 1, size_dict)
|
|
|
|
# Hadamard product (*)
|
|
assert 10 == oe.helpers.flop_count("a", False, 2, size_dict)
|
|
assert 100 == oe.helpers.flop_count("ab", False, 2, size_dict)
|
|
|
|
# Inner product (+, *)
|
|
assert 20 == oe.helpers.flop_count("a", True, 2, size_dict)
|
|
assert 200 == oe.helpers.flop_count("ab", True, 2, size_dict)
|
|
|
|
# Inner product x3 (+, *, *)
|
|
assert 30 == oe.helpers.flop_count("a", True, 3, size_dict)
|
|
|
|
# GEMM
|
|
assert 2000 == oe.helpers.flop_count("abc", True, 2, size_dict)
|
|
|
|
|
|
def test_bad_path_option() -> None:
|
|
with pytest.raises(KeyError):
|
|
oe.contract("a,b,c", [1], [2], [3], optimize="optimall", shapes=True) # type: ignore
|
|
|
|
|
|
def test_explicit_path() -> None:
|
|
pytest.importorskip("numpy")
|
|
x = oe.contract("a,b,c", [1], [2], [3], optimize=[(1, 2), (0, 1)])
|
|
assert x.item() == 6
|
|
|
|
|
|
def test_path_optimal() -> None:
|
|
test_func = oe.paths.optimal
|
|
|
|
test_data = explicit_path_tests["GEMM1"]
|
|
assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
|
|
assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])
|
|
|
|
|
|
def test_path_greedy() -> None:
|
|
test_func = oe.paths.greedy
|
|
|
|
test_data = explicit_path_tests["GEMM1"]
|
|
assert_contract_order(test_func, test_data, 5000, [(0, 2), (0, 1)])
|
|
assert_contract_order(test_func, test_data, 0, [(0, 1, 2)])
|
|
|
|
|
|
def test_memory_paths() -> None:
|
|
expression = "abc,bdef,fghj,cem,mhk,ljk->adgl"
|
|
|
|
views = build_shapes(expression)
|
|
|
|
# Test tiny memory limit
|
|
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=5, shapes=True)
|
|
assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])
|
|
|
|
path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=5, shapes=True)
|
|
assert check_path(path_ret[0], [(0, 1, 2, 3, 4, 5)])
|
|
|
|
# Check the possibilities, greedy is capped
|
|
path_ret = oe.contract_path(expression, *views, optimize="optimal", memory_limit=-1, shapes=True)
|
|
assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])
|
|
|
|
path_ret = oe.contract_path(expression, *views, optimize="greedy", memory_limit=-1, shapes=True)
|
|
assert check_path(path_ret[0], [(0, 3), (0, 4), (0, 2), (0, 2), (0, 1)])
|
|
|
|
|
|
@pytest.mark.parametrize("alg,expression,order", path_edge_tests)
|
|
def test_path_edge_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
|
|
views = build_shapes(expression)
|
|
|
|
# Test tiny memory limit
|
|
path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
|
|
assert check_path(path_ret[0], order)
|
|
|
|
|
|
@pytest.mark.parametrize("expression,order", path_scalar_tests)
|
|
@pytest.mark.parametrize("alg", oe.paths._PATH_OPTIONS)
|
|
def test_path_scalar_cases(alg: OptimizeKind, expression: str, order: PathType) -> None:
|
|
views = build_shapes(expression)
|
|
|
|
# Test tiny memory limit
|
|
path_ret = oe.contract_path(expression, *views, optimize=alg, shapes=True)
|
|
# print(path_ret[0])
|
|
assert len(path_ret[0]) == order
|
|
|
|
|
|
def test_optimal_edge_cases() -> None:
|
|
# Edge test5
|
|
expression = "a,ac,ab,ad,cd,bd,bc->"
|
|
edge_test4 = build_shapes(expression, dimension_dict={"a": 20, "b": 20, "c": 20, "d": 20})
|
|
path, _ = oe.contract_path(expression, *edge_test4, optimize="greedy", memory_limit="max_input", shapes=True)
|
|
assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])
|
|
|
|
path, _ = oe.contract_path(expression, *edge_test4, optimize="optimal", memory_limit="max_input", shapes=True)
|
|
assert check_path(path, [(0, 1), (0, 1, 2, 3, 4, 5)])
|
|
|
|
|
|
def test_greedy_edge_cases() -> None:
|
|
expression = "abc,cfd,dbe,efa"
|
|
dim_dict = {k: 20 for k in expression.replace(",", "")}
|
|
tensors = build_shapes(expression, dimension_dict=dim_dict)
|
|
|
|
path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit="max_input", shapes=True)
|
|
assert check_path(path, [(0, 1, 2, 3)])
|
|
|
|
path, _ = oe.contract_path(expression, *tensors, optimize="greedy", memory_limit=-1, shapes=True)
|
|
assert check_path(path, [(0, 1), (0, 2), (0, 1)])
|
|
|
|
|
|
def test_dp_edge_cases_dimension_1() -> None:
|
|
eq = "nlp,nlq,pl->n"
|
|
shapes = [(1, 1, 1), (1, 1, 1), (1, 1)]
|
|
info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1]
|
|
assert max(info.scale_list) == 3
|
|
|
|
|
|
def test_dp_edge_cases_all_singlet_indices() -> None:
|
|
eq = "a,bcd,efg->"
|
|
shapes = [(2,), (2, 2, 2), (2, 2, 2)]
|
|
info = oe.contract_path(eq, *shapes, shapes=True, optimize="dp")[1]
|
|
assert max(info.scale_list) == 3
|
|
|
|
|
|
def test_custom_dp_can_optimize_for_outer_products() -> None:
|
|
eq = "a,b,abc->c"
|
|
|
|
da, db, dc = 2, 2, 3
|
|
shapes = [(da,), (db,), (da, db, dc)]
|
|
|
|
opt1 = oe.DynamicProgramming(search_outer=False)
|
|
opt2 = oe.DynamicProgramming(search_outer=True)
|
|
|
|
info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
|
|
info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
|
|
|
|
assert info2.opt_cost < info1.opt_cost
|
|
|
|
|
|
def test_custom_dp_can_optimize_for_size() -> None:
|
|
eq, shapes = rand_equation(10, 4, seed=43)
|
|
|
|
opt1 = oe.DynamicProgramming(minimize="flops")
|
|
opt2 = oe.DynamicProgramming(minimize="size")
|
|
|
|
info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
|
|
info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
|
|
|
|
assert info1.opt_cost < info2.opt_cost
|
|
assert info1.largest_intermediate > info2.largest_intermediate
|
|
|
|
|
|
def test_custom_dp_can_set_cost_cap() -> None:
|
|
eq, shapes = rand_equation(5, 3, seed=42)
|
|
opt1 = oe.DynamicProgramming(cost_cap=True)
|
|
opt2 = oe.DynamicProgramming(cost_cap=False)
|
|
opt3 = oe.DynamicProgramming(cost_cap=100)
|
|
info1 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt1)[1]
|
|
info2 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt2)[1]
|
|
info3 = oe.contract_path(eq, *shapes, shapes=True, optimize=opt3)[1]
|
|
assert info1.opt_cost == info2.opt_cost == info3.opt_cost
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"minimize,cost,width,path",
|
|
[
|
|
("flops", 663054, 18900, [(4, 5), (2, 5), (2, 7), (5, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
|
|
("size", 1114440, 2016, [(2, 7), (3, 8), (3, 7), (2, 6), (1, 5), (1, 4), (1, 3), (1, 2), (0, 1)]),
|
|
("write", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
|
|
("combo", 973518, 2016, [(4, 5), (2, 5), (6, 7), (2, 6), (1, 5), (1, 4), (0, 3), (0, 2), (0, 1)]),
|
|
("limit", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
|
|
("combo-256", 983790, 2016, [(0, 8), (3, 4), (1, 4), (5, 6), (1, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
|
|
("limit-256", 983832, 2016, [(2, 7), (3, 4), (0, 4), (3, 6), (2, 5), (0, 4), (0, 3), (1, 2), (0, 1)]),
|
|
],
|
|
)
|
|
def test_custom_dp_can_set_minimize(minimize: str, cost: int, width: int, path: PathType) -> None:
|
|
eq, shapes = rand_equation(10, 4, seed=43)
|
|
opt = oe.DynamicProgramming(minimize=minimize)
|
|
info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)[1]
|
|
assert info.path == path
|
|
assert info.opt_cost == cost
|
|
assert info.largest_intermediate == width
|
|
|
|
|
|
def test_dp_errors_when_no_contractions_found() -> None:
|
|
eq, shapes = rand_equation(10, 3, seed=42)
|
|
|
|
# first get the actual minimum cost
|
|
opt = oe.DynamicProgramming(minimize="size")
|
|
_, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)
|
|
mincost = info.largest_intermediate
|
|
|
|
# check we can still find it without minimizing size explicitly
|
|
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost, optimize="dp")
|
|
|
|
# but check just below this threshold raises
|
|
with pytest.raises(RuntimeError):
|
|
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost - 1, optimize="dp")
|
|
|
|
|
|
@pytest.mark.parametrize("optimize", ["greedy", "branch-2", "branch-all", "optimal", "dp"])
|
|
def test_can_optimize_outer_products(optimize: OptimizeKind) -> None:
|
|
a, b, c = ((10, 10) for _ in range(3))
|
|
d = (10, 2)
|
|
|
|
assert oe.contract_path("ab,cd,ef,fg", a, b, c, d, optimize=optimize, shapes=True)[0] == [
|
|
(2, 3),
|
|
(0, 2),
|
|
(0, 1),
|
|
]
|
|
|
|
|
|
@pytest.mark.parametrize("num_symbols", [2, 3, 26, 26 + 26, 256 - 140, 300])
|
|
def test_large_path(num_symbols: int) -> None:
|
|
symbols = "".join(oe.get_symbol(i) for i in range(num_symbols))
|
|
dimension_dict = dict(zip(symbols, itertools.cycle([2, 3, 4])))
|
|
expression = ",".join(symbols[t : t + 2] for t in range(num_symbols - 1))
|
|
tensors = build_shapes(expression, dimension_dict=dimension_dict)
|
|
|
|
# Check that path construction does not crash
|
|
oe.contract_path(expression, *tensors, optimize="greedy", shapes=True)
|
|
|
|
|
|
def test_custom_random_greedy() -> None:
|
|
np = pytest.importorskip("numpy")
|
|
|
|
eq, shapes = rand_equation(10, 4, seed=42)
|
|
views = list(map(np.ones, shapes))
|
|
|
|
with pytest.raises(ValueError):
|
|
oe.RandomGreedy(minimize="something")
|
|
|
|
optimizer = oe.RandomGreedy(max_repeats=10, minimize="flops")
|
|
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
|
|
|
|
assert len(optimizer.costs) == 10
|
|
assert len(optimizer.sizes) == 10
|
|
|
|
assert path == optimizer.path
|
|
assert optimizer.best["flops"] == min(optimizer.costs)
|
|
assert path_info.largest_intermediate == optimizer.best["size"]
|
|
assert path_info.opt_cost == optimizer.best["flops"]
|
|
|
|
# check can change settings and run again
|
|
optimizer.temperature = 0.0
|
|
optimizer.max_repeats = 6
|
|
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
|
|
|
|
assert len(optimizer.costs) == 16
|
|
assert len(optimizer.sizes) == 16
|
|
|
|
assert path == optimizer.path
|
|
assert optimizer.best["size"] == min(optimizer.sizes)
|
|
assert path_info.largest_intermediate == optimizer.best["size"]
|
|
assert path_info.opt_cost == optimizer.best["flops"]
|
|
|
|
# check error if we try and reuse the optimizer on a different expression
|
|
eq, shapes = rand_equation(10, 4, seed=41)
|
|
views = list(map(np.ones, shapes))
|
|
with pytest.raises(ValueError):
|
|
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
|
|
|
|
|
|
def test_custom_branchbound() -> None:
|
|
np = pytest.importorskip("numpy")
|
|
|
|
eq, shapes = rand_equation(8, 4, seed=42)
|
|
views = list(map(np.ones, shapes))
|
|
optimizer = oe.BranchBound(nbranch=2, cutoff_flops_factor=10, minimize="size")
|
|
|
|
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
|
|
|
|
assert path == optimizer.path
|
|
assert path_info.largest_intermediate == optimizer.best["size"]
|
|
assert path_info.opt_cost == optimizer.best["flops"]
|
|
|
|
# tweak settings and run again
|
|
optimizer.nbranch = 3
|
|
optimizer.cutoff_flops_factor = 4
|
|
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
|
|
|
|
assert path == optimizer.path
|
|
assert path_info.largest_intermediate == optimizer.best["size"]
|
|
assert path_info.opt_cost == optimizer.best["flops"]
|
|
|
|
# check error if we try and reuse the optimizer on a different expression
|
|
eq, shapes = rand_equation(8, 4, seed=41)
|
|
views = list(map(np.ones, shapes))
|
|
with pytest.raises(ValueError):
|
|
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
|
|
|
|
|
|
def test_branchbound_validation() -> None:
|
|
with pytest.raises(ValueError):
|
|
oe.BranchBound(nbranch=0)
|
|
|
|
|
|
def test_parallel_random_greedy() -> None:
|
|
np = pytest.importorskip("numpy")
|
|
|
|
pool = ProcessPoolExecutor(2)
|
|
|
|
eq, shapes = rand_equation(10, 4, seed=42)
|
|
views = list(map(np.ones, shapes))
|
|
|
|
optimizer = oe.RandomGreedy(max_repeats=10, parallel=pool)
|
|
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
|
|
|
|
assert len(optimizer.costs) == 10
|
|
assert len(optimizer.sizes) == 10
|
|
|
|
assert path == optimizer.path
|
|
assert optimizer.parallel is pool
|
|
assert optimizer._executor is pool
|
|
assert optimizer.best["flops"] == min(optimizer.costs)
|
|
assert path_info.largest_intermediate == optimizer.best["size"]
|
|
assert path_info.opt_cost == optimizer.best["flops"]
|
|
|
|
# now switch to max time algorithm
|
|
optimizer.max_repeats = int(1e6)
|
|
optimizer.max_time = 0.2
|
|
optimizer.parallel = 2
|
|
|
|
path, path_info = oe.contract_path(eq, *views, optimize=optimizer)
|
|
|
|
assert len(optimizer.costs) > 10
|
|
assert len(optimizer.sizes) > 10
|
|
|
|
assert path == optimizer.path
|
|
assert optimizer.best["flops"] == min(optimizer.costs)
|
|
assert path_info.largest_intermediate == optimizer.best["size"]
|
|
assert path_info.opt_cost == optimizer.best["flops"]
|
|
|
|
optimizer.parallel = True
|
|
assert optimizer._executor is not None
|
|
assert optimizer._executor is not pool
|
|
|
|
are_done = [f.running() or f.done() for f in optimizer._futures]
|
|
assert all(are_done)
|
|
|
|
|
|
def test_custom_path_optimizer() -> None:
|
|
np = pytest.importorskip("numpy")
|
|
|
|
class NaiveOptimizer(oe.paths.PathOptimizer):
|
|
def __call__(
|
|
self,
|
|
inputs: List[ArrayIndexType],
|
|
output: ArrayIndexType,
|
|
size_dict: Dict[str, int],
|
|
memory_limit: Optional[int] = None,
|
|
) -> PathType:
|
|
self.was_used = True
|
|
return [(0, 1)] * (len(inputs) - 1)
|
|
|
|
eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
|
|
views = list(map(np.ones, shapes))
|
|
|
|
exp = oe.contract(eq, *views, optimize=False)
|
|
|
|
optimizer = NaiveOptimizer()
|
|
out = oe.contract(eq, *views, optimize=optimizer)
|
|
assert exp == out
|
|
assert optimizer.was_used
|
|
|
|
|
|
def test_custom_random_optimizer() -> None:
|
|
np = pytest.importorskip("numpy")
|
|
|
|
class NaiveRandomOptimizer(oe.path_random.RandomOptimizer):
|
|
@staticmethod
|
|
def random_path(
|
|
r: int, n: int, inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int]
|
|
) -> Any:
|
|
"""Picks a completely random contraction order."""
|
|
np.random.seed(r)
|
|
ssa_path: List[TensorShapeType] = []
|
|
remaining = set(range(n))
|
|
while len(remaining) > 1:
|
|
i, j = np.random.choice(list(remaining), size=2, replace=False)
|
|
remaining.add(n + len(ssa_path))
|
|
remaining.remove(i)
|
|
remaining.remove(j)
|
|
ssa_path.append((i, j))
|
|
cost, size = oe.path_random.ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
|
|
return ssa_path, cost, size
|
|
|
|
def setup(self, inputs: Any, output: Any, size_dict: Any) -> Any:
|
|
self.was_used = True
|
|
n = len(inputs)
|
|
trial_fn = self.random_path
|
|
trial_args = (n, inputs, output, size_dict)
|
|
return trial_fn, trial_args
|
|
|
|
eq, shapes = rand_equation(5, 3, seed=42, d_max=3)
|
|
views = list(map(np.ones, shapes))
|
|
|
|
exp = oe.contract(eq, *views, optimize=False)
|
|
|
|
optimizer = NaiveRandomOptimizer(max_repeats=16)
|
|
out = oe.contract(eq, *views, optimize=optimizer)
|
|
assert exp == out
|
|
assert optimizer.was_used
|
|
|
|
assert len(optimizer.costs) == 16
|
|
|
|
|
|
def test_optimizer_registration() -> None:
|
|
def custom_optimizer(
|
|
inputs: List[ArrayIndexType], output: ArrayIndexType, size_dict: Dict[str, int], memory_limit: Optional[int]
|
|
) -> PathType:
|
|
return [(0, 1)] * (len(inputs) - 1)
|
|
|
|
with pytest.raises(KeyError):
|
|
oe.paths.register_path_fn("optimal", custom_optimizer)
|
|
|
|
oe.paths.register_path_fn("custom", custom_optimizer)
|
|
assert "custom" in oe.paths._PATH_OPTIONS
|
|
|
|
eq = "ab,bc,cd"
|
|
shapes = [(2, 3), (3, 4), (4, 5)]
|
|
path, _ = oe.contract_path(eq, *shapes, shapes=True, optimize="custom") # type: ignore
|
|
assert path == [(0, 1), (0, 1)]
|
|
del oe.paths._PATH_OPTIONS["custom"]
|
|
|
|
|
|
def test_path_with_assumed_shapes() -> None:
|
|
path, _ = oe.contract_path("ab,bc,cd", [[5, 3]], [[2], [4]], [[3, 2]])
|
|
assert path == [(0, 1), (0, 1)]
|