2025-08-11 12:24:21 +08:00

33 lines
950 B
Python

"""Required functions for optimized contractions of numpy arrays using cupy."""
from opt_einsum.helpers import has_array_interface
from opt_einsum.sharing import to_backend_cache_wrap
__all__ = ["to_cupy", "build_expression", "evaluate_constants"]
@to_backend_cache_wrap
def to_cupy(array): # pragma: no cover
import cupy
if has_array_interface(array):
return cupy.asarray(array)
return array
def build_expression(_, expr): # pragma: no cover
"""Build a cupy function based on ``arrays`` and ``expr``."""
def cupy_contract(*arrays):
return expr._contract([to_cupy(x) for x in arrays], backend="cupy").get()
return cupy_contract
def evaluate_constants(const_arrays, expr): # pragma: no cover
"""Convert constant arguments to cupy arrays, and perform any possible
constant contractions.
"""
return expr(*[to_cupy(x) for x in const_arrays], backend="cupy", evaluate_constants=True)