140 lines
4.1 KiB
Python
140 lines
4.1 KiB
Python
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
# See https://llvm.org/LICENSE.txt for license information.
|
|
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
|
|
|
|
from ._scf_ops_gen import *
|
|
from ._scf_ops_gen import _Dialect
|
|
from .arith import constant
|
|
|
|
try:
|
|
from ..ir import *
|
|
from ._ods_common import (
|
|
get_op_result_or_value as _get_op_result_or_value,
|
|
get_op_results_or_values as _get_op_results_or_values,
|
|
_cext as _ods_cext,
|
|
)
|
|
except ImportError as e:
|
|
raise RuntimeError("Error loading imports from extension module") from e
|
|
|
|
from typing import Optional, Sequence, Union
|
|
|
|
|
|
@_ods_cext.register_operation(_Dialect, replace=True)
class ForOp(ForOp):
    """Specialization for the SCF for op class."""

    def __init__(
        self,
        lower_bound,
        upper_bound,
        step,
        iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates an SCF `for` operation.

        - `lower_bound` is the value to use as lower bound of the loop.
        - `upper_bound` is the value to use as upper bound of the loop.
        - `step` is the value to use as loop step.
        - `iter_args` is a list of additional loop-carried arguments or an operation
          producing them as results.
        - `loc` / `ip` are forwarded unchanged to the generated op builder.

        The op gets one result per loop-carried value (same types), and its body
        region receives a single entry block whose arguments are the induction
        variable (typed like the lower-bound operand) followed by one argument
        per loop-carried value.
        """
        carried = _get_op_results_or_values([] if iter_args is None else iter_args)
        carried_types = [value.type for value in carried]

        super().__init__(
            carried_types, lower_bound, upper_bound, step, carried, loc=loc, ip=ip
        )

        # Entry block signature: (induction variable, *loop-carried values).
        # operands[0] is the lower bound, given the builder argument order above.
        induction_type = self.operands[0].type
        self.regions[0].blocks.append(induction_type, *carried_types)

    @property
    def body(self):
        """Returns the loop body: the entry block of the op's first region."""
        region = self.regions[0]
        return region.blocks[0]

    @property
    def induction_variable(self):
        """Returns the induction variable, i.e. the first body block argument."""
        return self.body.arguments[0]

    @property
    def inner_iter_args(self):
        """Returns the loop-carried arguments usable within the loop.

        These are the body block arguments that follow the induction variable.
        To obtain the loop-carried operands, use `iter_args`.
        """
        return self.body.arguments[1:]
|
|
|
|
|
|
@_ods_cext.register_operation(_Dialect, replace=True)
class IfOp(IfOp):
    """Specialization for the SCF if op class."""

    def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None):
        """Creates an SCF `if` operation.

        - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed.
        - `results_` is an optional sequence forwarded as the first argument of the
          generated builder (presumably the op's result types); defaults to none.
        - `hasElse` determines whether the if operation has the else branch.
        - `loc` / `ip` are forwarded unchanged to the generated op builder.

        The then-region always receives one argument-less block; the else-region
        receives one only when `hasElse` is true.
        """
        # Always hand the builder a fresh list, never the caller's own sequence.
        results = [] if results_ is None else list(results_)
        super().__init__(results, cond, loc=loc, ip=ip)
        self.regions[0].blocks.append()
        if hasElse:
            self.regions[1].blocks.append()

    @property
    def then_block(self):
        """Returns the then block of the if operation."""
        return self.regions[0].blocks[0]

    @property
    def else_block(self):
        """Returns the else block of the if operation.

        Only valid when the op was built with `hasElse=True`; otherwise the else
        region has no block and the indexing below fails.
        """
        return self.regions[1].blocks[0]
|
|
|
|
|
|
def for_(
    start,
    stop=None,
    step=None,
    iter_args: Optional[Sequence[Value]] = None,
    *,
    loc=None,
    ip=None,
):
    """Builds an `scf.for` loop and yields with the insertion point in its body.

    Bounds follow Python's `range()` conventions: `for_(n)` iterates over
    `[0, n)` with step 1, and `for_(a, b)` uses step 1. Each of `start`,
    `stop` and `step` may be an MLIR value or a Python `int`. Ints are
    materialized as `index`-typed constants, and floats are rejected with
    `ValueError`.

    - `iter_args` are the initial loop-carried values (optional).
    - `loc` / `ip` are used both for the loop op and for any index constants
      created for integer bounds. This keeps the constants at the same
      insertion point as, and ahead of, the loop.

    Yields (while the body block is the active insertion point):
    - no iter_args: the induction variable;
    - one iter_arg: `(iv, inner_iter_arg, result)`;
    - several: `(iv, tuple_of_inner_iter_args, op_results)`.

    No terminator is inserted into the body here. Presumably the caller emits
    the loop's `scf.yield` inside the `for` statement.
    """
    if step is None:
        step = 1
    if stop is None:
        # Single-argument form, like `range(stop)`.
        stop = start
        start = 0

    def _to_index(p):
        # Python ints become index constants placed where the loop is built.
        if isinstance(p, int):
            return constant(IndexType.get(), p, loc=loc, ip=ip)
        if isinstance(p, float):
            raise ValueError(f"{p=} must be int.")
        return p

    start, stop, step = (_to_index(p) for p in (start, stop, step))

    for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip)
    iv = for_op.induction_variable
    iter_args = tuple(for_op.inner_iter_args)
    with InsertionPoint(for_op.body):
        if len(iter_args) > 1:
            yield iv, iter_args, for_op.results
        elif len(iter_args) == 1:
            yield iv, iter_args[0], for_op.results[0]
        else:
            yield iv
|