280 lines
8.4 KiB
Python
280 lines
8.4 KiB
Python
"""
|
|
Tests a series of opt_einsum contraction paths to ensure the results are the same for different paths
|
|
"""
|
|
|
|
from typing import Any, List
|
|
|
|
import pytest
|
|
|
|
from opt_einsum import contract, contract_expression, contract_path
|
|
from opt_einsum.paths import _PATH_OPTIONS, linear_to_ssa, ssa_to_linear
|
|
from opt_einsum.testing import build_views, rand_equation
|
|
from opt_einsum.typing import OptimizeKind
|
|
|
|
# NumPy is required for the majority of this file
|
|
np = pytest.importorskip("numpy")
|
|
|
|
|
|
# Einsum equations shared by the parametrized tests in this module. Each entry
# should appear only once: duplicates merely re-run an identical test case.
tests = [
    # Test scalar-like operations
    "a,->a",
    "ab,->ab",
    ",ab,->ab",
    ",,->",
    # Test hadamard-like products
    "a,ab,abc->abc",
    "a,b,ab->ab",
    # Test index-transformations
    "ea,fb,gc,hd,abcd->efgh",
    "ea,fb,abcd,gc,hd->efgh",
    "abcd,ea,fb,gc,hd->efgh",
    # Test complex contractions
    "acdf,jbje,gihb,hfac,gfac,gifabc,hfac",
    "cd,bdhe,aidb,hgca,gc,hgibcd,hgac",
    "abhe,hidj,jgba,hiab,gab",
    "bde,cdh,agdb,hica,ibd,hgicd,hiac",
    "chd,bde,agbc,hiad,hgc,hgi,hiad",
    "chd,bde,agbc,hiad,bdi,cgh,agdb",
    "bdhe,acad,hiab,agac,hibd",
    # Test collapse
    "ab,ab,c->",
    "ab,ab,c->c",
    "ab,ab,cd,cd->",
    "ab,ab,cd,cd->ac",
    "ab,ab,cd,cd->cd",
    "ab,ab,cd,cd,ef,ef->",
    # Test outer products
    "ab,cd,ef->abcdef",
    "ab,cd,ef->acdf",
    "ab,cd,de->abcde",
    "ab,cd,de->be",
    "ab,bcd,cd->abcd",
    "ab,bcd,cd->abd",
    # Random test cases that have previously failed
    "eb,cb,fb->cef",
    "dd,fb,be,cdb->cef",
    "bca,cdb,dbf,afc->",
    "dcc,fce,ea,dbf->ab",
    "fdf,cdd,ccd,afe->ae",
    "abcd,ad",
    "ed,fcd,ff,bcf->be",
    "baa,dcf,af,cde->be",
    "bd,db,eac->ace",
    "fff,fae,bef,def->abd",
    "efc,dbc,acf,fd->abe",
    # Inner products
    "ab,ab",
    "ab,ba",
    "abc,abc",
    "abc,bac",
    "abc,cba",
    # GEMM test cases
    "ab,bc",
    "ab,cb",
    "ba,bc",
    "ba,cb",
    "abcd,cd",
    "abcd,ab",
    "abcd,cdef",
    "abcd,cdef->feba",
    "abcd,efdc",
    # Inner then dot
    "aab,bc->ac",
    "ab,bcc->ac",
    "aab,bcc->ac",
    "baa,bcc->ac",
    "aab,ccb->ac",
    # Randomly built test cases
    "aab,fa,df,ecc->bde",
    "ecb,fef,bad,ed->ac",
    "bcf,bbb,fbf,fc->",
    "bb,ff,be->e",
    "bcb,bb,fc,fff->",
    "fbb,dfd,fc,fc->",
    "afd,ba,cc,dc->bf",
    "adb,bc,fa,cfc->d",
    "bbd,bda,fc,db->acf",
    "dba,ead,cad->bce",
    "aef,fbc,dca->bde",
]
|
|
|
|
|
|
@pytest.mark.parametrize("optimize", (True, False, None))
def test_contract_plain_types(optimize: OptimizeKind) -> None:
    """Plain ``True``/``False``/``None`` are accepted as ``optimize`` values.

    ``contract_path`` must return a ``(path, path_info)`` pair, and ``contract``
    must return the same values as ``numpy.einsum``, not merely the right shape.
    """
    expr = "ij,jk,kl->il"
    ops = [np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)]

    path_result = contract_path(expr, *ops, optimize=optimize)
    # The return value is a (path, path_info) pair for every optimizer choice.
    assert len(path_result) == 2

    result = contract(expr, *ops, optimize=optimize)
    assert result.shape == (2, 2)
    # A shape check alone would accept a wrong result, so compare with NumPy.
    assert np.allclose(result, np.einsum(expr, *ops))
|
|
|
|
|
|
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare(optimize: OptimizeKind, string: str) -> None:
    """Each path optimizer agrees with the unoptimized contraction (BLAS disabled)."""
    operands = build_views(string)

    reference = contract(string, *operands, optimize=False, use_blas=False)
    optimized = contract(string, *operands, optimize=optimize, use_blas=False)

    assert np.allclose(reference, optimized)
|
|
|
|
|
|
@pytest.mark.parametrize("string", tests)
def test_drop_in_replacement(string: str) -> None:
    """``contract`` with default settings returns what ``numpy.einsum`` returns."""
    operands = build_views(string)
    ours = contract(string, *operands)
    theirs = np.einsum(string, *operands)
    assert np.allclose(ours, theirs)
|
|
|
|
|
|
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_greek(optimize: OptimizeKind, string: str) -> None:
    """Greek index letters give the same result as the Latin reference (BLAS off)."""
    operands = build_views(string)
    reference = contract(string, *operands, optimize=False, use_blas=False)

    # Shift every index letter by 848 code points (e.g. "a" -> "\u03b1"),
    # leaving the separators ",", "-", ">" and "." untouched.
    greek = "".join(ch if ch in ",->." else chr(ord(ch) + 848) for ch in string)

    optimized = contract(greek, *operands, optimize=optimize, use_blas=False)
    assert np.allclose(reference, optimized)
|
|
|
|
|
|
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_blas(optimize: OptimizeKind, string: str) -> None:
    """Each path optimizer agrees with the unoptimized contraction (BLAS allowed)."""
    operands = build_views(string)

    reference = contract(string, *operands, optimize=False)
    optimized = contract(string, *operands, optimize=optimize)

    assert np.allclose(reference, optimized)
|
|
|
|
|
|
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
def test_compare_blas_greek(optimize: OptimizeKind, string: str) -> None:
    """Greek index letters give the same result as the Latin reference (BLAS allowed)."""
    operands = build_views(string)
    reference = contract(string, *operands, optimize=False)

    # Shift every index letter by 848 code points (e.g. "a" -> "\u03b1"),
    # leaving the separators ",", "-", ">" and "." untouched.
    greek = "".join(ch if ch in ",->." else chr(ord(ch) + 848) for ch in string)

    optimized = contract(greek, *operands, optimize=optimize)
    assert np.allclose(reference, optimized)
|
|
|
|
|
|
def test_some_non_alphabet_maintains_order() -> None:
    """Implicit output order is kept when Latin and non-Latin indices are mixed.

    ``"c" + beta + "a"`` should implicitly produce ``"a c beta"``. Beta is
    temporarily replaced with ``"b"`` internally, and ``"cba"`` alone would give
    ``"abc"``, so this checks that the manually computed output is used. The
    reference ``"cxa"`` implicitly produces ``"acx"``, which has the same axis
    order.
    """
    beta = chr(ord("b") + 848)
    x = np.random.rand(2, 3, 4)
    assert np.allclose(contract("c" + beta + "a", x), contract("cxa", x))
|
|
|
|
|
|
def test_printing() -> None:
    """The string report from ``contract_path`` has a stable length."""
    string = "bbd,bda,fc,db->acf"
    views = build_views(string)

    # contract_path returns (path, path_info); only the report is checked here.
    _, path_info = contract_path(string, *views)

    # Regression guard on the report format: 728 is the expected character
    # count of the report for this equation with the operands from build_views.
    assert len(str(path_info)) == 728
|
|
|
|
|
|
@pytest.mark.parametrize("string", tests)
@pytest.mark.parametrize("optimize", _PATH_OPTIONS)
@pytest.mark.parametrize("use_blas", [False, True])
@pytest.mark.parametrize("out_spec", [False, True])
def test_contract_expressions(string: str, optimize: OptimizeKind, use_blas: bool, out_spec: bool) -> None:
    """A compiled ``contract_expression`` reproduces the reference contraction.

    When ``out_spec`` is set and the equation names explicit, non-empty output
    indices, the result is written into a preallocated ``out`` array instead.
    """
    operands = build_views(string)
    shapes = [getattr(op, "shape", ()) for op in operands]
    reference = contract(string, *operands, optimize=False, use_blas=False)

    expr = contract_expression(string, *shapes, optimize=optimize, use_blas=use_blas)

    has_output_indices = "->" in string and not string.endswith("->")
    if out_spec and has_output_indices:
        (result,) = build_views(string.split("->")[1])
        expr(*operands, out=result)
    else:
        result = expr(*operands)

    assert np.allclose(result, reference)

    # Both textual representations should mention the original equation.
    assert string in repr(expr)
    assert string in str(expr)
|
|
|
|
|
|
def test_contract_expression_interleaved_input() -> None:
    """``contract_expression`` accepts the interleaved (operand, sublist) form."""
    x, y, z = (np.random.randn(2, 2) for _ in "xyz")
    expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0])

    shape = (2, 2)
    expr = contract_expression(shape, [0, 1], shape, [1, 2], shape, [2, 3], [3, 0])
    assert np.allclose(expr(x, y, z), expected)
|
|
|
|
|
|
@pytest.mark.parametrize(
    "string,constants",
    [
        ("hbc,bdef,cdkj,ji,ikeh,lfo", [1, 2, 3, 4]),
        ("bdef,cdkj,ji,ikeh,hbc,lfo", [0, 1, 2, 3]),
        ("ijab,acd,bce,df,ef->ji", [1, 2, 3, 4]),
        ("ab,cd,ad,cb", [1, 3]),
        ("ab,bc,cd", [0, 1]),
    ],
)
def test_contract_expression_with_constants(string: str, constants: List[int]) -> None:
    """An expression with some operands fixed as constants matches the reference.

    Operands whose positions are listed in ``constants`` are passed to
    ``contract_expression`` as arrays; the rest are passed as shapes and
    supplied only when the expression is called.
    """
    views = build_views(string)
    expected = contract(string, *views, optimize=False, use_blas=False)

    expr_args: List[Any] = []
    ctrc_args: List[Any] = []
    for i, view in enumerate(views):
        if i in constants:
            # Constant operand: baked into the expression as an array.
            expr_args.append(view)
        else:
            # Variable operand: described by shape now, supplied at call time.
            expr_args.append(view.shape if hasattr(view, "shape") else ())
            ctrc_args.append(view)

    expr = contract_expression(string, *expr_args, constants=constants)
    out = expr(*ctrc_args)
    assert np.allclose(expected, out)
|
|
|
|
|
|
@pytest.mark.parametrize("optimize", ["greedy", "optimal"])
@pytest.mark.parametrize("n", [4, 5])
@pytest.mark.parametrize("reg", [2, 3])
@pytest.mark.parametrize("n_out", [0, 2, 4])
@pytest.mark.parametrize("global_dim", [False, True])
def test_rand_equation(optimize: OptimizeKind, n: int, reg: int, n_out: int, global_dim: bool) -> None:
    """Optimized contraction of a random equation matches the unoptimized one.

    ``global_dim`` is forwarded to ``rand_equation``. Previously it was ignored,
    so both parametrized variants exercised the same equations.
    """
    eq, _, size_dict = rand_equation(
        n,
        reg,
        n_out,
        d_min=2,
        d_max=5,
        seed=42,
        global_dim=global_dim,
        return_size_dict=True,
    )
    views = build_views(eq, size_dict)

    expected = contract(eq, *views, optimize=False)
    actual = contract(eq, *views, optimize=optimize)

    assert np.allclose(expected, actual)
|
|
|
|
|
|
@pytest.mark.parametrize("equation", tests)
def test_linear_vs_ssa(equation: str) -> None:
    """Converting a linear path to SSA form and back reproduces the original path."""
    operands = build_views(equation)
    linear_path, _path_info = contract_path(equation, *operands)
    round_tripped = ssa_to_linear(linear_to_ssa(linear_path))
    assert round_tripped == linear_path
|
|
|
|
|
|
def test_contract_path_supply_shapes() -> None:
    """``contract_path(..., shapes=True)`` works from bare shape tuples.

    The path found from shapes alone should equal the path found from arrays of
    those shapes, since path finding here should depend only on the dimensions.
    """
    eq = "ab,bc,cd"
    shps = [(2, 3), (3, 4), (4, 5)]
    path, _ = contract_path(eq, *shps, shapes=True)

    # Previously the result was never checked; compare against real arrays.
    arrays = [np.random.rand(*shp) for shp in shps]
    expected_path, _ = contract_path(eq, *arrays)
    assert path == expected_path
|