# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import operator
from itertools import accumulate
from typing import Optional

from ._memref_ops_gen import *
from ._ods_common import _dispatch_mixed_values, MixedValues
from .arith import ConstantOp, _is_integer_like_type
from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation


def _is_constant_int_like(i):
    """Tell whether `i` is an integer-like result of a `ConstantOp`.

    True only when `i` is a `Value` whose owner is an `Operation`, that
    operation's opview is a `ConstantOp`, and `_is_integer_like_type` accepts
    the value's type.
    """
    if not isinstance(i, Value):
        return False
    producer = i.owner
    if not isinstance(producer, Operation):
        return False
    return isinstance(producer.opview, ConstantOp) and _is_integer_like_type(i.type)


def _is_static_int_like(i):
    """Tell whether `i` is statically known.

    That is either a Python `int` that `ShapedType.is_dynamic_size` does not
    flag as dynamic, or a value accepted by `_is_constant_int_like`.
    """
    if isinstance(i, int) and not ShapedType.is_dynamic_size(i):
        return True
    return _is_constant_int_like(i)


def _infer_memref_subview_result_type(
    source_memref_type, offsets, static_sizes, static_strides
):
    """Compute the result memref type of a subview with static sizes/strides.

    Args:
        source_memref_type: Type of the memref being viewed; it must provide
            `get_strides_and_offset()`, `element_type` and `memory_space`.
        offsets: Per-dimension offsets (Python ints and/or `Value`s).
        static_sizes: Per-dimension sizes; each must be static int-like.
        static_strides: Per-dimension strides; each must be static int-like.

    Returns:
        A 4-tuple `(offsets, static_sizes, static_strides, result_type)` where
        the three lists are copies of the inputs in which every constant-op
        result has been replaced by its `literal_value`, and `result_type` is
        the inferred `MemRefType`.

    Raises:
        ValueError: If any size, stride or source stride is not static
            int-like.
    """
    source_strides, source_offset = source_memref_type.get_strides_and_offset()

    # Copy into fresh lists: the inputs may be tuples, and the lists are
    # edited in place below.
    offsets = list(offsets)
    static_sizes = list(static_sizes)
    static_strides = list(static_strides)
    source_strides = list(source_strides)

    for operands in (static_sizes, static_strides, source_strides):
        if not all(_is_static_int_like(v) for v in operands):
            raise ValueError(
                "Only inferring from python or mlir integer constant is supported."
            )

    # Swap every constant-op result for the Python literal it carries.
    for operands in (offsets, static_sizes, static_strides):
        for pos, v in enumerate(operands):
            if _is_constant_int_like(v):
                operands[pos] = v.owner.opview.literal_value

    if all(_is_static_int_like(v) for v in [*offsets, source_offset]):
        # Source offset plus each offset scaled by the matching source stride.
        target_offset = sum(
            (off * stride for off, stride in zip(offsets, source_strides)),
            source_offset,
        )
    else:
        # At least one offset is not statically known.
        target_offset = ShapedType.get_dynamic_size()

    target_strides = [
        src * step for src, step in zip(source_strides, static_strides)
    ]

    # If default striding then no need to complicate things for downstream ops
    # (e.g., expand_shape).
    default_strides = list(accumulate(reversed(static_sizes[1:]), operator.mul))
    default_strides.reverse()
    default_strides.append(1)

    uses_default_layout = target_strides == default_strides and target_offset == 0
    layout = (
        None
        if uses_default_layout
        else StridedLayoutAttr.get(target_offset, target_strides)
    )

    result_type = MemRefType.get(
        static_sizes,
        source_memref_type.element_type,
        layout,
        source_memref_type.memory_space,
    )
    return offsets, static_sizes, static_strides, result_type


# Keep a handle on the `subview` already in scope (presumably the builder
# pulled in by the star import of `._memref_ops_gen` - confirm) before the
# name is rebound by the wrapper below.
_generated_subview = subview


def subview(
    source: Value,
    offsets: MixedValues,
    sizes: MixedValues,
    strides: MixedValues,
    *,
    result_type: Optional[MemRefType] = None,
    loc=None,
    ip=None,
):
    """Build a subview of `source`, inferring the result type when possible.

    Args:
        source: The memref value to take a view of.
        offsets: Per-dimension offsets; `None` is treated as an empty list.
        sizes: Per-dimension sizes; `None` is treated as an empty list.
        strides: Per-dimension strides; `None` is treated as an empty list.
        result_type: Explicit result type. When omitted it is inferred, which
            requires every size, stride and source stride to be static
            int-like.
        loc: Forwarded unchanged to the wrapped `subview`.
        ip: Forwarded unchanged to the wrapped `subview`.

    Returns:
        Whatever the wrapped `_generated_subview` returns.

    Raises:
        ValueError: If `result_type` is omitted and the sizes, strides or
            source strides are not all static int-like.
    """
    offsets = [] if offsets is None else offsets
    sizes = [] if sizes is None else sizes
    strides = [] if strides is None else strides

    source_strides, _ = source.type.get_strides_and_offset()

    if result_type is None:
        can_infer = all(
            _is_static_int_like(v)
            for operands in (sizes, strides, source_strides)
            for v in operands
        )
        if not can_infer:
            raise ValueError(
                "mixed static/dynamic offset/sizes/strides requires explicit result type."
            )
        # If any are arith.constant results then this will canonicalize to
        # python int (which can then be used to fully specify the subview).
        offsets, sizes, strides, result_type = _infer_memref_subview_result_type(
            source.type, offsets, sizes, strides
        )

    # Each call yields a triple; only the first and last members are used.
    dynamic_offsets, _, static_offsets = _dispatch_mixed_values(offsets)
    dynamic_sizes, _, static_sizes = _dispatch_mixed_values(sizes)
    dynamic_strides, _, static_strides = _dispatch_mixed_values(strides)

    return _generated_subview(
        result_type,
        source,
        dynamic_offsets,
        dynamic_sizes,
        dynamic_strides,
        static_offsets,
        static_sizes,
        static_strides,
        loc=loc,
        ip=ip,
    )