# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
import inspect
|
|
from functools import wraps
|
|
|
|
from ..dialects._ods_common import get_op_result_or_op_results
|
|
from ..ir import Type, InsertionPoint
|
|
|
|
|
|
def op_region_builder(op, op_region, terminator=None):
    """Return a decorator that populates `op_region` from a Python function.

    The returned callable accepts a single function, `body_builder`, and:

    1. If `op_region` has no blocks yet, appends one whose argument types are
       the parameter annotations of `body_builder`. Every annotation must be
       an instance of `Type`; otherwise a `ValueError` is raised.
    2. Calls `body_builder` with the arguments of the region's first block,
       inside an `InsertionPoint` for that block.
    3. If `terminator` is not None, calls it inside an `InsertionPoint` for
       the region's last block, passing a list built from the return value of
       `body_builder`: a tuple/list is copied element-wise, `None` becomes an
       empty list, and anything else becomes a one-element list.
    4. Returns `get_op_result_or_op_results(op)`.
    """

    def builder_wrapper(body_builder):
        # An empty region gets its entry block from the type hints on the
        # wrapped function.
        if len(op_region.blocks) == 0:
            sig = inspect.signature(body_builder)
            params = sig.parameters
            types = [param.annotation for param in params.values()]
            # Un-annotated parameters report `inspect.Parameter.empty`, which
            # is not a `Type`, so they are rejected by the isinstance test.
            if len(types) != len(params) or any(
                not isinstance(annotation, Type) for annotation in types
            ):
                raise ValueError(
                    f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
                )
            op_region.blocks.append(*types)

        # Run the user's builder with the entry block as insertion point.
        entry_block = op_region.blocks[0]
        with InsertionPoint(entry_block):
            block_args = list(entry_block.arguments)
            results = body_builder(*block_args)

        # The builder may have added blocks; the terminator goes in the last.
        last_block = list(op_region.blocks)[-1]
        with InsertionPoint(last_block):
            if terminator is not None:
                # Normalise the builder's return value to a plain list.
                if isinstance(results, (tuple, list)):
                    operands = list(results)
                elif results is None:
                    operands = []
                else:
                    operands = [results]
                terminator(operands)

        return get_op_result_or_op_results(op)

    return builder_wrapper
|
|
|
|
|
|
def region_op(op_constructor, terminator=None):
    """Turn an Op constructor into a decorator that builds the Op's region.

    An `mlir.ir.InsertionPoint` and an `mlir.ir.Location` must be active for
    the current thread (i.e. established in enclosing `with` blocks).

    The resulting decorator can be used "naked" (without parentheses) when the
    Op constructor needs no arguments, or called with the arguments that
    should be forwarded to the Op constructor.

    Applying it to a Python function creates an entry block for the Op whose
    argument types are taken **from the type hints on the function's
    parameters**; the block arguments are then passed positionally to the
    function.

    When a terminator is given, the value returned by the decorated function
    is handed to it as the final statement of the block. The terminator always
    receives a (possibly empty) list, so a terminator that expects a single
    value should be wrapped, e.g. `lambda args: term(args[0])`.

    After decoration, the function's name is bound to:
      1. a single value, if the Op has exactly one result;
      2. an OpResultList (as a list), if the Op has several results;
      3. the Operation itself, if the Op has no results.

    See examples in tensor.py and transform.extras.
    """

    # NOTE: this inner name is copied onto the returned callable by `wraps`
    # below, so it is part of the observable metadata.
    def op_decorator(*args, **kwargs):
        op = op_constructor(*args, **kwargs)
        # Only the first region of the Op is populated.
        return op_region_builder(op, op.regions[0], terminator)

    @wraps(op_decorator)
    def maybe_no_args(*args, **kwargs):
        # "Naked" usage: the sole positional argument is the function being
        # decorated, so build the Op with no constructor arguments.
        used_naked = len(args) == 1 and not kwargs and callable(args[0])
        if not used_naked:
            return op_decorator(*args, **kwargs)
        return op_decorator()(args[0])

    return maybe_no_args
|