152 lines
4.2 KiB
Python
152 lines
4.2 KiB
Python
"""Contains helper functions for opt_einsum testing scripts."""
|
|
|
|
from typing import Any, Collection, Dict, FrozenSet, Iterable, List, Tuple, overload
|
|
|
|
from opt_einsum.typing import ArrayIndexType, ArrayType
|
|
|
|
__all__ = ["compute_size_by_dict", "find_contraction", "flop_count"]
|
|
|
|
_valid_chars = "abcdefghijklmopqABC"
|
|
_sizes = [2, 3, 4, 5, 4, 3, 2, 6, 5, 4, 3, 2, 5, 7, 4, 3, 2, 3, 4]
|
|
_default_dim_dict = dict(zip(_valid_chars, _sizes))
|
|
|
|
|
|
@overload
|
|
def compute_size_by_dict(indices: Iterable[int], idx_dict: List[int]) -> int: ...
|
|
|
|
|
|
@overload
|
|
def compute_size_by_dict(indices: Collection[str], idx_dict: Dict[str, int]) -> int: ...
|
|
|
|
|
|
def compute_size_by_dict(indices: Any, idx_dict: Any) -> int:
|
|
"""Computes the product of the elements in indices based on the dictionary
|
|
idx_dict.
|
|
|
|
Parameters
|
|
----------
|
|
indices : iterable
|
|
Indices to base the product on.
|
|
idx_dict : dictionary
|
|
Dictionary of index _sizes
|
|
|
|
Returns:
|
|
-------
|
|
ret : int
|
|
The resulting product.
|
|
|
|
Examples:
|
|
--------
|
|
>>> compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
|
|
90
|
|
|
|
"""
|
|
ret = 1
|
|
for i in indices: # lgtm [py/iteration-string-and-sequence]
|
|
ret *= idx_dict[i]
|
|
return ret
|
|
|
|
|
|
def find_contraction(
|
|
positions: Collection[int],
|
|
input_sets: List[ArrayIndexType],
|
|
output_set: ArrayIndexType,
|
|
) -> Tuple[FrozenSet[str], List[ArrayIndexType], ArrayIndexType, ArrayIndexType]:
|
|
"""Finds the contraction for a given set of input and output sets.
|
|
|
|
Parameters
|
|
----------
|
|
positions : iterable
|
|
Integer positions of terms used in the contraction.
|
|
input_sets : list
|
|
List of sets that represent the lhs side of the einsum subscript
|
|
output_set : set
|
|
Set that represents the rhs side of the overall einsum subscript
|
|
|
|
Returns:
|
|
-------
|
|
new_result : set
|
|
The indices of the resulting contraction
|
|
remaining : list
|
|
List of sets that have not been contracted, the new set is appended to
|
|
the end of this list
|
|
idx_removed : set
|
|
Indices removed from the entire contraction
|
|
idx_contraction : set
|
|
The indices used in the current contraction
|
|
|
|
Examples:
|
|
--------
|
|
# A simple dot product test case
|
|
>>> pos = (0, 1)
|
|
>>> isets = [set('ab'), set('bc')]
|
|
>>> oset = set('ac')
|
|
>>> find_contraction(pos, isets, oset)
|
|
({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
|
|
|
|
# A more complex case with additional terms in the contraction
|
|
>>> pos = (0, 2)
|
|
>>> isets = [set('abd'), set('ac'), set('bdc')]
|
|
>>> oset = set('ac')
|
|
>>> find_contraction(pos, isets, oset)
|
|
({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
|
|
"""
|
|
remaining = list(input_sets)
|
|
inputs = (remaining.pop(i) for i in sorted(positions, reverse=True))
|
|
idx_contract = frozenset.union(*inputs)
|
|
idx_remain = output_set.union(*remaining)
|
|
|
|
new_result = idx_remain & idx_contract
|
|
idx_removed = idx_contract - new_result
|
|
remaining.append(new_result)
|
|
|
|
return new_result, remaining, idx_removed, idx_contract
|
|
|
|
|
|
def flop_count(
|
|
idx_contraction: Collection[str],
|
|
inner: bool,
|
|
num_terms: int,
|
|
size_dictionary: Dict[str, int],
|
|
) -> int:
|
|
"""Computes the number of FLOPS in the contraction.
|
|
|
|
Parameters
|
|
----------
|
|
idx_contraction : iterable
|
|
The indices involved in the contraction
|
|
inner : bool
|
|
Does this contraction require an inner product?
|
|
num_terms : int
|
|
The number of terms in a contraction
|
|
size_dictionary : dict
|
|
The size of each of the indices in idx_contraction
|
|
|
|
Returns:
|
|
-------
|
|
flop_count : int
|
|
The total number of FLOPS required for the contraction.
|
|
|
|
Examples:
|
|
--------
|
|
>>> flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
|
|
30
|
|
|
|
>>> flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
|
|
60
|
|
|
|
"""
|
|
overall_size = compute_size_by_dict(idx_contraction, size_dictionary)
|
|
op_factor = max(1, num_terms - 1)
|
|
if inner:
|
|
op_factor += 1
|
|
|
|
return overall_size * op_factor
|
|
|
|
|
|
def has_array_interface(array: ArrayType) -> ArrayType:
|
|
if hasattr(array, "__array_interface__"):
|
|
return True
|
|
else:
|
|
return False
|