137 lines
4.2 KiB
Python
137 lines
4.2 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
import operator
|
|
from itertools import accumulate
|
|
from typing import Optional
|
|
|
|
from ._memref_ops_gen import *
|
|
from ._ods_common import _dispatch_mixed_values, MixedValues
|
|
from .arith import ConstantOp, _is_integer_like_type
|
|
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation
|
|
|
|
|
|
def _is_constant_int_like(i):
    """Tell whether `i` is an integer-like `Value` produced by a `ConstantOp`.

    The checks run in order and stop at the first failure:
    `i` must be a `Value`, its owner must be an `Operation`, that operation's
    opview must be a `ConstantOp`, and `_is_integer_like_type` must accept the
    value's type.
    """
    if not isinstance(i, Value):
        return False
    # The owner is only inspected further when it is an Operation.
    if not isinstance(i.owner, Operation):
        return False
    if not isinstance(i.owner.opview, ConstantOp):
        return False
    return _is_integer_like_type(i.type)
|
|
|
|
|
|
def _is_static_int_like(i):
    """Tell whether `i` can be treated as a statically known integer.

    That is the case when `i` is a Python `int` that
    `ShapedType.is_dynamic_size` does not flag as dynamic, or when
    `_is_constant_int_like` accepts it.
    """
    if isinstance(i, int) and not ShapedType.is_dynamic_size(i):
        return True
    # Not a plain static int: fall back to the constant-op check.
    return _is_constant_int_like(i)
|
|
|
|
|
|
def _infer_memref_subview_result_type(
    source_memref_type, offsets, static_sizes, static_strides
):
    """Infer the memref type produced by a subview with static sizes/strides.

    Args:
        source_memref_type: Type of the memref being viewed. It must provide
            `get_strides_and_offset()`, `element_type` and `memory_space`.
        offsets: Sequence of offsets (Python ints and/or `Value`s).
        static_sizes: Sequence of sizes; every entry must be static int-like.
        static_strides: Sequence of strides; every entry must be static
            int-like.

    Returns:
        A 4-tuple `(offsets, static_sizes, static_strides, result_type)`.
        The three lists are copies of the inputs in which every entry accepted
        by `_is_constant_int_like` has been replaced by the `literal_value` of
        its defining op. `result_type` is the inferred `MemRefType`.

    Raises:
        ValueError: If any size, stride or source stride is not static
            int-like.
    """
    source_strides, source_offset = source_memref_type.get_strides_and_offset()

    # Take list copies so tuples are accepted and concatenation below works.
    offsets = list(offsets)
    static_sizes = list(static_sizes)
    static_strides = list(static_strides)
    source_strides = list(source_strides)

    inferable = all(
        _is_static_int_like(entry)
        for group in (static_sizes, static_strides, source_strides)
        for entry in group
    )
    if not inferable:
        raise ValueError(
            "Only inferring from python or mlir integer constant is supported."
        )

    def _fold(entry):
        # Constant-op results are swapped for the literal they carry.
        if _is_constant_int_like(entry):
            return entry.owner.opview.literal_value
        return entry

    offsets = [_fold(entry) for entry in offsets]
    static_sizes = [_fold(entry) for entry in static_sizes]
    static_strides = [_fold(entry) for entry in static_strides]

    if all(_is_static_int_like(entry) for entry in offsets + [source_offset]):
        # Everything is known: accumulate offset * stride on top of the
        # source offset.
        target_offset = sum(
            (offset * stride for offset, stride in zip(offsets, source_strides)),
            source_offset,
        )
    else:
        # At least one contribution is not static, so the offset is dynamic.
        target_offset = ShapedType.get_dynamic_size()

    target_strides = [
        source_stride * static_stride
        for source_stride, static_stride in zip(source_strides, static_strides)
    ]

    # If default striding then no need to complicate things for downstream ops
    # (e.g., expand_shape).
    trailing_products = list(accumulate(reversed(static_sizes[1:]), operator.mul))
    trailing_products.reverse()
    default_strides = trailing_products + [1]

    uses_default_layout = target_strides == default_strides and target_offset == 0
    layout = (
        None
        if uses_default_layout
        else StridedLayoutAttr.get(target_offset, target_strides)
    )

    result_type = MemRefType.get(
        static_sizes,
        source_memref_type.element_type,
        layout,
        source_memref_type.memory_space,
    )
    return offsets, static_sizes, static_strides, result_type
|
|
|
|
|
|
# Keep a handle on the `subview` that is already in scope (presumably the
# generated builder brought in by the star import of `._memref_ops_gen`)
# before the `def subview` below rebinds the name; the wrapper delegates to it.
_generated_subview = subview
|
|
|
|
|
|
def subview(
    source: Value,
    offsets: MixedValues,
    sizes: MixedValues,
    strides: MixedValues,
    *,
    result_type: Optional[MemRefType] = None,
    loc=None,
    ip=None,
):
    """Build a subview of `source`, inferring the result type when possible.

    Args:
        source: The memref value to take a view of.
        offsets: Offsets of the view; `None` is treated as an empty list.
        sizes: Sizes of the view; `None` is treated as an empty list.
        strides: Strides of the view; `None` is treated as an empty list.
        result_type: Explicit result type. When omitted it is inferred, which
            requires all sizes, strides and source strides to be static
            int-like.
        loc: Forwarded to the generated builder.
        ip: Forwarded to the generated builder.

    Returns:
        Whatever the generated `subview` builder returns.

    Raises:
        ValueError: If `result_type` is omitted and the sizes, strides or
            source strides are not all static int-like.
    """
    offsets = [] if offsets is None else offsets
    sizes = [] if sizes is None else sizes
    strides = [] if strides is None else strides

    # Queried unconditionally, exactly as before; only the strides are used.
    source_strides, _ = source.type.get_strides_and_offset()

    if result_type is None:
        fully_static = all(
            _is_static_int_like(entry)
            for group in (sizes, strides, source_strides)
            for entry in group
        )
        if not fully_static:
            raise ValueError(
                "mixed static/dynamic offset/sizes/strides requires explicit result type."
            )
        # If any are arith.constant results then this will canonicalize to
        # python int (which can then be used to fully specify the subview).
        offsets, sizes, strides, result_type = _infer_memref_subview_result_type(
            source.type, offsets, sizes, strides
        )

    # Split each mixed list into its dynamic operands and static attribute.
    dynamic_offsets, _, static_offsets = _dispatch_mixed_values(offsets)
    dynamic_sizes, _, static_sizes = _dispatch_mixed_values(sizes)
    dynamic_strides, _, static_strides = _dispatch_mixed_values(strides)

    return _generated_subview(
        result_type,
        source,
        dynamic_offsets,
        dynamic_sizes,
        dynamic_strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )
|