"""Determines if a contraction can use BLAS or not."""
|
|
|
|
from typing import List, Sequence, Tuple, Union
|
|
|
|
from opt_einsum.typing import ArrayIndexType
|
|
|
|
__all__ = ["can_blas"]
|
|
|
|
|
|
def can_blas(
|
|
inputs: List[str],
|
|
result: str,
|
|
idx_removed: ArrayIndexType,
|
|
shapes: Union[Sequence[Tuple[int]], None] = None,
|
|
) -> Union[str, bool]:
|
|
"""Checks if we can use a BLAS call.
|
|
|
|
Parameters
|
|
----------
|
|
inputs : list of str
|
|
Specifies the subscripts for summation.
|
|
result : str
|
|
Resulting summation.
|
|
idx_removed : set
|
|
Indices that are removed in the summation
|
|
shapes : sequence of tuple[int], optional
|
|
If given, check also that none of the indices are broadcast dimensions.
|
|
|
|
Returns:
|
|
-------
|
|
type : str or bool
|
|
The type of BLAS call to be used or False if none.
|
|
|
|
Notes:
|
|
-----
|
|
We assume several operations are not efficient such as a transposed
|
|
DDOT, therefore 'ijk,jki->' should prefer einsum. These return the blas
|
|
type appended with "/EINSUM" to differentiate when they can still be done
|
|
with tensordot if required, e.g. when a backend has no einsum.
|
|
|
|
Examples:
|
|
--------
|
|
>>> can_blas(['ij', 'jk'], 'ik', set('j'))
|
|
'GEMM'
|
|
|
|
>>> can_blas(['ijj', 'jk'], 'ik', set('j'))
|
|
False
|
|
|
|
>>> can_blas(['ab', 'cd'], 'abcd', set())
|
|
'OUTER/EINSUM'
|
|
|
|
>>> # looks like GEMM but actually 'j' is broadcast:
|
|
>>> can_blas(['ij', 'jk'], 'ik', set('j'), shapes=[(4, 1), (5, 6)])
|
|
False
|
|
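
    >>> # the transposed dot case named in the Notes ('ijk,jki->') also defers to einsum:
    >>> can_blas(['ijk', 'jki'], '', set('ijk'))
    'DOT/EINSUM'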
"""
    # Can only do two
    if len(inputs) != 2:
        return False

    input_left, input_right = inputs

    for c in set(input_left + input_right):
        # can't deal with repeated indices on same input or more than 2 total
        nl, nr = input_left.count(c), input_right.count(c)
        if (nl > 1) or (nr > 1) or (nl + nr > 2):
            return False

        # can't do implicit summation or dimension collapse e.g.
        #     "ab,bc->c" (implicitly sum over 'a')
        #     "ab,ca->ca" (take diagonal of 'a')
        if nl + nr - 1 == int(c in result):
            return False

    # check for broadcast indices, e.g.:
    #     "ij,jk->ik" (but one of the 'j' dimensions is broadcast up)
    if shapes is not None:
        for c in idx_removed:
            if shapes[0][input_left.find(c)] != shapes[1][input_right.find(c)]:
                return False

    # Prefer einsum if not removing indices
    #     (N.B. tensordot outer faster for large arrays?)
    if len(idx_removed) == 0:
        return "OUTER/EINSUM"

    # Build a few temporaries
    sets = [set(x) for x in inputs]
    keep_left = sets[0] - idx_removed
    keep_right = sets[1] - idx_removed
    rs = len(idx_removed)

    # DDOT
    if inputs[0] == inputs[1]:
        return "DOT"

    # DDOT does not make sense if you have to transpose - prefer einsum
    elif sets[0] == sets[1]:
        return "DOT/EINSUM"
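
    # The slice checks below test whether the rs contracted indices form a
    # single contiguous, identically ordered block at the start or end of each
    # term; if so, the contraction maps onto one (possibly transposed) matrix
    # multiply.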
    # GEMM no transpose
    if input_left[-rs:] == input_right[:rs]:
        return "GEMM"

    # GEMM transpose both
    elif input_left[:rs] == input_right[-rs:]:
        return "GEMM"

    # GEMM transpose right
    elif input_left[-rs:] == input_right[-rs:]:
        return "GEMM"

    # GEMM transpose left
    elif input_left[:rs] == input_right[:rs]:
        return "GEMM"

    # Einsum is faster than a GEMV call if we have to copy data
    elif (len(keep_left) == 0) or (len(keep_right) == 0):
        return "GEMV/EINSUM"

    # Conventional tensordot
    else:
        return "TDOT"